diff --git a/Cargo.lock b/Cargo.lock index 432d8283f7..62e59d0146 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2137,6 +2137,7 @@ dependencies = [ "qsc_doc_gen", "qsc_eval", "qsc_fir", + "qsc_fir_transforms", "qsc_formatter", "qsc_frontend", "qsc_hir", @@ -2274,6 +2275,32 @@ dependencies = [ "rustc-hash", ] +[[package]] +name = "qsc_fir_transforms" +version = "0.0.0" +dependencies = [ + "expect-test", + "indoc", + "miette", + "num-bigint", + "proptest", + "qsc_codegen", + "qsc_data_structures", + "qsc_eval", + "qsc_fir", + "qsc_fir_transforms", + "qsc_formatter", + "qsc_frontend", + "qsc_hir", + "qsc_lowerer", + "qsc_parse", + "qsc_partial_eval", + "qsc_passes", + "qsc_rca", + "rustc-hash", + "thiserror", +] + [[package]] name = "qsc_formatter" version = "0.0.0" @@ -2410,6 +2437,7 @@ dependencies = [ "qsc_fir", "qsc_frontend", "qsc_lowerer", + "qsc_passes", "qsc_rca", "qsc_rir", "rustc-hash", @@ -2468,6 +2496,7 @@ dependencies = [ "qsc", "qsc_data_structures", "qsc_fir", + "qsc_fir_transforms", "qsc_frontend", "qsc_lowerer", "qsc_passes", diff --git a/Cargo.toml b/Cargo.toml index 688b27457b..888aed196a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ members = [ "source/compiler/qsc_doc_gen", "source/compiler/qsc_eval", "source/compiler/qsc_fir", + "source/compiler/qsc_fir_transforms", "source/compiler/qsc_frontend", "source/compiler/qsc_hir", "source/compiler/qsc_openqasm_compiler", diff --git a/source/compiler/qsc/Cargo.toml b/source/compiler/qsc/Cargo.toml index ae22ffb526..912651f26e 100644 --- a/source/compiler/qsc/Cargo.toml +++ b/source/compiler/qsc/Cargo.toml @@ -26,6 +26,7 @@ qsc_linter = { path = "../qsc_linter" } qsc_lowerer = { path = "../qsc_lowerer" } qsc_ast = { path = "../qsc_ast" } qsc_fir = { path = "../qsc_fir" } +qsc_fir_transforms = { path = "../qsc_fir_transforms" } qsc_hir = { path = "../qsc_hir" } qsc_passes = { path = "../qsc_passes" } qsc_parse = { path = "../qsc_parse" } diff --git a/source/compiler/qsc/src/codegen.rs 
b/source/compiler/qsc/src/codegen.rs index 5538d07f39..7145dd5b2b 100644 --- a/source/compiler/qsc/src/codegen.rs +++ b/source/compiler/qsc/src/codegen.rs @@ -11,16 +11,1384 @@ pub mod qsharp { pub mod qir { use qsc_codegen::qir::{fir_to_qir, fir_to_rir}; + use qsc_eval::val::Value; + use qsc_fir::fir::Package; use qsc_data_structures::{ - error::WithSource, language_features::LanguageFeatures, source::SourceMap, - target::TargetCapabilityFlags, + error::WithSource, functors::FunctorApp, language_features::LanguageFeatures, + source::SourceMap, target::TargetCapabilityFlags, }; use qsc_frontend::compile::{Dependencies, PackageStore}; use qsc_partial_eval::{PartialEvalConfig, ProgramEntry}; - use qsc_passes::{PackageType, PassContext}; + use qsc_passes::{PackageType, PassContext, run_rca_for_callable}; + use rustc_hash::FxHashSet; use crate::interpret::Error; + + /// Flat Intermediate Representation (FIR) ready for QIR/RIR code generation. + /// + /// Contains: + /// - `fir_store`: Complete lowered FIR package store after all compiler passes + /// - `fir_package_id`: Main package ID within the store + /// - `compute_properties`: Resource analysis (qubit/instruction counts, etc.) + /// + /// Invariants (when created with full pipeline): + /// - No type parameters remain (monomorphization complete) + /// - No return statements (return unification complete) + /// - No arrow types or closures (defunctionalization complete) + /// - No UDT types (UDT erasure complete) + /// - Execution graphs fully populated + pub struct CodegenFir { + pub fir_store: qsc_fir::fir::PackageStore, + pub fir_package_id: qsc_fir::fir::PackageId, + pub compute_properties: qsc_rca::PackageStoreComputeProperties, + } + + /// Extracts the entry point expression from codegen FIR. + /// + /// Forms a `ProgramEntry` suitable for downstream codegen (QIR, RIR generation) + /// by combining the entry expression and its associated execution graph. 
+ pub(crate) fn entry_from_codegen_fir(prepared_fir: &CodegenFir) -> ProgramEntry { + let package = prepared_fir.fir_store.get(prepared_fir.fir_package_id); + ProgramEntry { + exec_graph: package.entry_exec_graph.clone(), + expr: ( + prepared_fir.fir_package_id, + package + .entry + .expect("package must have an entry expression"), + ) + .into(), + } + } + + fn clone_fir_package(package: &Package) -> Package { + Package { + items: package.items.clone(), + entry: package.entry, + entry_exec_graph: package.entry_exec_graph.clone(), + blocks: package.blocks.clone(), + exprs: package.exprs.clone(), + pats: package.pats.clone(), + stmts: package.stmts.clone(), + } + } + + fn clone_fir_store(fir_store: &qsc_fir::fir::PackageStore) -> qsc_fir::fir::PackageStore { + let mut cloned_store = qsc_fir::fir::PackageStore::new(); + for (package_id, package) in fir_store { + cloned_store.insert(package_id, clone_fir_package(package)); + } + cloned_store + } + + fn lower_to_fir( + package_store: &PackageStore, + package_id: qsc_hir::hir::PackageId, + package_override: Option<&qsc_hir::hir::Package>, + ) -> ( + qsc_fir::fir::PackageStore, + qsc_fir::fir::PackageId, + qsc_fir::assigner::Assigner, + ) { + if let Some(package_override) = package_override { + let mut fir_store = qsc_fir::fir::PackageStore::new(); + let mut fir_assigner = qsc_fir::assigner::Assigner::new(); + + for (id, unit) in package_store { + let hir_package = if id == package_id { + package_override + } else { + &unit.package + }; + + let mut lowerer = qsc_lowerer::Lowerer::new(); + let fir_package = if id == package_id { + let mut fir_package = Package { + items: Default::default(), + entry: None, + entry_exec_graph: Default::default(), + blocks: Default::default(), + exprs: Default::default(), + pats: Default::default(), + stmts: Default::default(), + }; + lowerer.lower_and_update_package(&mut fir_package, hir_package); + fir_package.entry_exec_graph = lowerer.take_exec_graph(); + fir_package + } else { + 
lowerer.lower_package(hir_package, &fir_store) + }; + if id == package_id { + fir_assigner = lowerer.into_assigner(); + } + fir_store.insert(qsc_lowerer::map_hir_package_to_fir(id), fir_package); + } + + ( + fir_store, + qsc_lowerer::map_hir_package_to_fir(package_id), + fir_assigner, + ) + } else { + qsc_passes::lower_hir_to_fir(package_store, package_id) + } + } + + /// Runs the full FIR transformation pipeline through all stages. + /// + /// Applies compiler passes (monomorphization, defunctionalization, UDT erasure, etc.) + /// to produce codegen-ready FIR satisfying full invariants. + pub fn run_codegen_pipeline( + package_store: &PackageStore, + package_id: qsc_hir::hir::PackageId, + fir_store: &mut qsc_fir::fir::PackageStore, + fir_package_id: qsc_fir::fir::PackageId, + ) -> Result<(), Vec> { + run_codegen_pipeline_to( + package_store, + package_id, + fir_store, + fir_package_id, + qsc_fir_transforms::PipelineStage::Full, + &[], + ) + } + + /// Runs the FIR pipeline up to a specified stage with optional item pinning. + /// + /// Allows fine-grained control over pipeline execution: + /// - `stage`: Which pipeline stage to stop at (e.g., `PipelineStage::Full` for all passes) + /// - `pinned_items`: Callables to preserve even if not reached from entry + /// (useful for callable arguments that might otherwise be eliminated by DCE) + /// + /// This is critical for higher-order function support: when a callable is passed + /// as an argument, it may not be directly reachable from entry and would normally be + /// removed during dead-code elimination. Pinning preserves these for specialization. 
+ pub fn run_codegen_pipeline_to( + package_store: &PackageStore, + package_id: qsc_hir::hir::PackageId, + fir_store: &mut qsc_fir::fir::PackageStore, + fir_package_id: qsc_fir::fir::PackageId, + stage: qsc_fir_transforms::PipelineStage, + pinned_items: &[qsc_fir::fir::StoreItemId], + ) -> Result<(), Vec> { + // CONTRACT: On success, `run_pipeline_to` with `PipelineStage::Full` produces FIR + // satisfying `InvariantLevel::PostAll`: + // - No `Ty::Param` in reachable code (monomorphization completed). + // - No `ExprKind::Return` in reachable code (return unification completed). + // - No `Ty::Arrow` params / `ExprKind::Closure` (defunctionalization completed). + // - No `Ty::Udt` / `ExprKind::Struct` / `Field::Path` (UDT erasure completed). + // - All exec-graph ranges populated (exec-graph rebuild completed). + // Downstream codegen (QIR lowering, partial evaluation) assumes these invariants hold. + // See `qsc_fir_transforms::invariants::check` for the authoritative checker. + let pipeline_errors = + qsc_fir_transforms::run_pipeline_to(fir_store, fir_package_id, stage, pinned_items); + if !pipeline_errors.is_empty() { + let source_package = package_store + .get(package_id) + .expect("package should be in store"); + return Err(pipeline_errors + .into_iter() + .map(|e| Error::FirTransform(WithSource::from_map(&source_package.sources, e))) + .collect()); + } + + Ok(()) + } + + fn map_pass_errors( + package_store: &PackageStore, + package_id: qsc_hir::hir::PackageId, + errors: Vec, + ) -> Vec { + let source_package = package_store + .get(package_id) + .expect("package should be in store"); + + errors + .into_iter() + .map(|e| Error::Pass(WithSource::from_map(&source_package.sources, e))) + .collect() + } + + fn validate_callable_capabilities( + package_store: &PackageStore, + fir_store: &qsc_fir::fir::PackageStore, + compute_properties: &qsc_rca::PackageStoreComputeProperties, + callable: qsc_fir::fir::StoreItemId, + capabilities: TargetCapabilityFlags, + ) -> 
Result<(), Vec> { + let errors = run_rca_for_callable(fir_store, compute_properties, callable, capabilities); + if errors.is_empty() { + Ok(()) + } else { + Err(map_pass_errors( + package_store, + qsc_lowerer::map_fir_package_to_hir(callable.package), + errors, + )) + } + } + + /// Returns true if a type is, or structurally contains, a callable arrow type. + /// + /// Arrays, tuples, and UDT pure types are traversed recursively so callers can + /// detect callable fields even before UDT erasure has normalized the type shape. + fn ty_contains_arrow(ty: &qsc_fir::ty::Ty, fir_store: &qsc_fir::fir::PackageStore) -> bool { + match ty { + qsc_fir::ty::Ty::Array(item) => ty_contains_arrow(item, fir_store), + qsc_fir::ty::Ty::Arrow(_) => true, + qsc_fir::ty::Ty::Tuple(items) => { + items.iter().any(|item| ty_contains_arrow(item, fir_store)) + } + qsc_fir::ty::Ty::Udt(res) => { + let qsc_fir::fir::Res::Item(item_id) = res else { + return false; + }; + let package = fir_store.get(item_id.package); + let item = package + .items + .get(item_id.item) + .expect("UDT item should exist"); + let qsc_fir::fir::ItemKind::Ty(_, udt) = &item.kind else { + return false; + }; + ty_contains_arrow(&udt.get_pure_ty(), fir_store) + } + qsc_fir::ty::Ty::Infer(_) + | qsc_fir::ty::Ty::Param(_) + | qsc_fir::ty::Ty::Prim(_) + | qsc_fir::ty::Ty::Err => false, + } + } + + fn callable_has_arrow_input( + fir_store: &qsc_fir::fir::PackageStore, + callable: qsc_hir::hir::ItemId, + ) -> bool { + use qsc_fir::fir::{Global, PackageLookup}; + + let callable_store_id = qsc_fir::fir::StoreItemId { + package: qsc_lowerer::map_hir_package_to_fir(callable.package), + item: qsc_lowerer::map_hir_local_item_to_fir(callable.item), + }; + + let package = fir_store.get(callable_store_id.package); + let Some(Global::Callable(callable_decl)) = package.get_global(callable_store_id.item) + else { + panic!("callable should exist in lowered package"); + }; + + ty_contains_arrow(&package.get_pat(callable_decl.input).ty, 
fir_store) + } + + fn seed_entry_with_callable( + fir_store: &mut qsc_fir::fir::PackageStore, + fir_package_id: qsc_fir::fir::PackageId, + callable: qsc_hir::hir::ItemId, + assigner: &mut qsc_fir::assigner::Assigner, + ) { + let callable_store_id = qsc_fir::fir::StoreItemId { + package: qsc_lowerer::map_hir_package_to_fir(callable.package), + item: qsc_lowerer::map_hir_local_item_to_fir(callable.item), + }; + + let (span, ty) = { + use qsc_fir::fir::{Global, PackageLookup}; + + let package = fir_store.get(callable_store_id.package); + let Some(Global::Callable(callable_decl)) = package.get_global(callable_store_id.item) + else { + panic!("callable should exist in lowered package"); + }; + + let input = package.get_pat(callable_decl.input).ty.clone(); + let ty = qsc_fir::ty::Ty::Arrow(Box::new(qsc_fir::ty::Arrow { + kind: callable_decl.kind, + input: Box::new(input), + output: Box::new(callable_decl.output.clone()), + functors: qsc_fir::ty::FunctorSet::Value(callable_decl.functors), + })); + + (callable_decl.span, ty) + }; + + let entry_expr_id = assigner.next_expr(); + let package = fir_store.get_mut(fir_package_id); + package.exprs.insert( + entry_expr_id, + qsc_fir::fir::Expr { + id: entry_expr_id, + span, + ty, + kind: qsc_fir::fir::ExprKind::Var( + qsc_fir::fir::Res::Item(qsc_fir::fir::ItemId { + package: callable_store_id.package, + item: callable_store_id.item, + }), + Vec::new(), + ), + exec_graph_range: qsc_fir::fir::ExecGraphIdx::ZERO + ..qsc_fir::fir::ExecGraphIdx::ZERO, + }, + ); + package.entry = Some(entry_expr_id); + package.entry_exec_graph = Default::default(); + } + + fn callable_expr_span_and_ty( + fir_store: &qsc_fir::fir::PackageStore, + callable_store_id: qsc_fir::fir::StoreItemId, + ) -> (qsc_data_structures::span::Span, qsc_fir::ty::Ty) { + use qsc_fir::fir::{Global, PackageLookup}; + + let package = fir_store.get(callable_store_id.package); + let Some(Global::Callable(callable_decl)) = package.get_global(callable_store_id.item) + else { + 
panic!("callable should exist in lowered package"); + }; + + let input = package.get_pat(callable_decl.input).ty.clone(); + let ty = qsc_fir::ty::Ty::Arrow(Box::new(qsc_fir::ty::Arrow { + kind: callable_decl.kind, + input: Box::new(input), + output: Box::new(callable_decl.output.clone()), + functors: qsc_fir::ty::FunctorSet::Value(callable_decl.functors), + })); + + (callable_decl.span, ty) + } + + fn seed_entry_with_callables( + fir_store: &mut qsc_fir::fir::PackageStore, + fir_package_id: qsc_fir::fir::PackageId, + callables: &FxHashSet, + ) { + if callables.is_empty() { + return; + } + + let mut assigner = qsc_fir::assigner::Assigner::from_package(fir_store.get(fir_package_id)); + + let mut entry_exprs = Vec::with_capacity(callables.len()); + let mut entry_tys = Vec::with_capacity(callables.len()); + let mut entry_span = None; + + for callable in callables { + let (span, ty) = callable_expr_span_and_ty(fir_store, *callable); + let expr_id = assigner.next_expr(); + let package = fir_store.get_mut(fir_package_id); + package.exprs.insert( + expr_id, + qsc_fir::fir::Expr { + id: expr_id, + span, + ty: ty.clone(), + kind: qsc_fir::fir::ExprKind::Var( + qsc_fir::fir::Res::Item(qsc_fir::fir::ItemId { + package: callable.package, + item: callable.item, + }), + Vec::new(), + ), + exec_graph_range: qsc_fir::fir::ExecGraphIdx::ZERO + ..qsc_fir::fir::ExecGraphIdx::ZERO, + }, + ); + entry_exprs.push(expr_id); + entry_tys.push(ty); + entry_span.get_or_insert(span); + } + + let entry_expr_id = if entry_exprs.len() == 1 { + entry_exprs[0] + } else { + let entry_expr_id = assigner.next_expr(); + let package = fir_store.get_mut(fir_package_id); + package.exprs.insert( + entry_expr_id, + qsc_fir::fir::Expr { + id: entry_expr_id, + span: entry_span.expect("tuple entry should have a span"), + ty: qsc_fir::ty::Ty::Tuple(entry_tys), + kind: qsc_fir::fir::ExprKind::Tuple(entry_exprs), + exec_graph_range: qsc_fir::fir::ExecGraphIdx::ZERO + ..qsc_fir::fir::ExecGraphIdx::ZERO, + }, + ); + 
entry_expr_id + }; + + let package = fir_store.get_mut(fir_package_id); + package.entry = Some(entry_expr_id); + package.entry_exec_graph = Default::default(); + } + + /// Builds a pre-computed map of callable types for all Global/Closure values in `args`. + /// + /// This allows `lower_value_to_expr` to look up arrow types without holding an immutable + /// reference to the package store while also mutating a package. + fn build_callable_type_map( + fir_store: &qsc_fir::fir::PackageStore, + callables: &FxHashSet, + ) -> rustc_hash::FxHashMap { + let mut map = + rustc_hash::FxHashMap::with_capacity_and_hasher(callables.len(), Default::default()); + for id in callables { + let (_, ty) = callable_expr_span_and_ty(fir_store, *id); + map.insert(*id, ty); + } + map + } + + /// Seeds the package entry with a synthetic `Call(target, args)` expression. + /// + /// Builds args matching the target callable's pure input type: callable-typed positions + /// are filled with Var references to the concrete callables from the `args` Value; + /// non-callable positions get typed placeholder literals (which are never evaluated — + /// they exist only to make the Call structurally valid for defunctionalization). + fn seed_entry_with_call_to_target( + fir_store: &mut qsc_fir::fir::PackageStore, + fir_package_id: qsc_fir::fir::PackageId, + target_callable: qsc_fir::fir::StoreItemId, + args: &Value, + callable_types: &rustc_hash::FxHashMap, + ) { + use qsc_fir::fir::{Global, PackageLookup}; + + // Pre-compute target's arrow type and input pattern type (immutable borrow of store). 
+ let package = fir_store.get(target_callable.package); + let Some(Global::Callable(callable_decl)) = package.get_global(target_callable.item) else { + panic!("target callable must exist in lowered package"); + }; + let span = callable_decl.span; + let input_pat = package.get_pat(callable_decl.input); + let input_ty = resolve_functor_params(&resolve_udt_ty(fir_store, &input_pat.ty)); + let output_ty = resolve_functor_params(&resolve_udt_ty(fir_store, &callable_decl.output)); + let arrow_ty = qsc_fir::ty::Ty::Arrow(Box::new(qsc_fir::ty::Arrow { + kind: callable_decl.kind, + input: Box::new(input_ty.clone()), + output: Box::new(output_ty.clone()), + functors: qsc_fir::ty::FunctorSet::Value(callable_decl.functors), + })); + + // Build concrete generic args for the callee Var so monomorphization can + // resolve FunctorSet::Param in the specialized clone's body types. + let generic_args = build_concrete_generic_args(&callable_decl.generics); + + // Build assigner from the package's current ID counters. + let mut assigner = qsc_fir::assigner::Assigner::from_package(fir_store.get(fir_package_id)); + + // Get the package mutably and build args expression matching the input type. + let package = fir_store.get_mut(fir_package_id); + let args_expr_id = + build_synthetic_args(package, &mut assigner, &input_ty, args, callable_types); + + // Create callee Var expression referencing the target callable. + let callee_expr_id = assigner.next_expr(); + package.exprs.insert( + callee_expr_id, + qsc_fir::fir::Expr { + id: callee_expr_id, + span, + ty: arrow_ty, + kind: qsc_fir::fir::ExprKind::Var( + qsc_fir::fir::Res::Item(qsc_fir::fir::ItemId { + package: target_callable.package, + item: target_callable.item, + }), + generic_args, + ), + exec_graph_range: qsc_fir::fir::ExecGraphIdx::ZERO + ..qsc_fir::fir::ExecGraphIdx::ZERO, + }, + ); + + // Create Call expression: Call(callee, args) with output type. 
+ let call_expr_id = assigner.next_expr(); + package.exprs.insert( + call_expr_id, + qsc_fir::fir::Expr { + id: call_expr_id, + span, + ty: output_ty, + kind: qsc_fir::fir::ExprKind::Call(callee_expr_id, args_expr_id), + exec_graph_range: qsc_fir::fir::ExecGraphIdx::ZERO + ..qsc_fir::fir::ExecGraphIdx::ZERO, + }, + ); + + // Set entry to the synthetic Call. + package.entry = Some(call_expr_id); + package.entry_exec_graph = Default::default(); + } + + /// Builds an args expression matching the target's input type. + /// + /// For callable-typed positions, uses the corresponding callable from `args`. + /// For non-callable positions, uses `lower_value_to_expr` if the value is available + /// in `args`, otherwise creates a typed placeholder literal. + fn build_synthetic_args( + package: &mut qsc_fir::fir::Package, + assigner: &mut qsc_fir::assigner::Assigner, + input_ty: &qsc_fir::ty::Ty, + args: &Value, + callable_types: &rustc_hash::FxHashMap, + ) -> qsc_fir::fir::ExprId { + match input_ty { + qsc_fir::ty::Ty::Tuple(elem_tys) if elem_tys.is_empty() => { + // Unit input — create empty tuple expression. + let expr_id = assigner.next_expr(); + package.exprs.insert( + expr_id, + qsc_fir::fir::Expr { + id: expr_id, + span: qsc_data_structures::span::Span::default(), + ty: qsc_fir::ty::Ty::Tuple(Vec::new()), + kind: qsc_fir::fir::ExprKind::Tuple(Vec::new()), + exec_graph_range: qsc_fir::fir::ExecGraphIdx::ZERO + ..qsc_fir::fir::ExecGraphIdx::ZERO, + }, + ); + expr_id + } + qsc_fir::ty::Ty::Tuple(elem_tys) => { + // Multi-param input — walk each position. + // If args is a Tuple of same length, pair element-wise. + // Otherwise, match the first callable-typed position to args. + let arg_elems: Vec<&Value> = match args { + Value::Tuple(vs, _) if vs.len() == elem_tys.len() => vs.iter().collect(), + _ => { + // Args doesn't match tuple structure — build with + // args placed at the first arrow-typed position. 
+ let mut elem_ids = Vec::with_capacity(elem_tys.len()); + let mut args_used = false; + for elem_ty in elem_tys { + if !args_used && ty_is_arrow_or_contains_arrow(elem_ty) { + elem_ids.push(lower_value_to_expr( + package, + assigner, + args, + callable_types, + )); + args_used = true; + } else { + elem_ids.push(make_placeholder_expr(package, assigner, elem_ty)); + } + } + let expr_id = assigner.next_expr(); + package.exprs.insert( + expr_id, + qsc_fir::fir::Expr { + id: expr_id, + span: qsc_data_structures::span::Span::default(), + ty: input_ty.clone(), + kind: qsc_fir::fir::ExprKind::Tuple(elem_ids), + exec_graph_range: qsc_fir::fir::ExecGraphIdx::ZERO + ..qsc_fir::fir::ExecGraphIdx::ZERO, + }, + ); + return expr_id; + } + }; + + // Element-wise matching: lower each arg against its declared type. + let mut elem_ids = Vec::with_capacity(elem_tys.len()); + for (elem_ty, arg_val) in elem_tys.iter().zip(arg_elems.iter()) { + elem_ids.push(build_synthetic_args( + package, + assigner, + elem_ty, + arg_val, + callable_types, + )); + } + let expr_id = assigner.next_expr(); + package.exprs.insert( + expr_id, + qsc_fir::fir::Expr { + id: expr_id, + span: qsc_data_structures::span::Span::default(), + ty: input_ty.clone(), + kind: qsc_fir::fir::ExprKind::Tuple(elem_ids), + exec_graph_range: qsc_fir::fir::ExecGraphIdx::ZERO + ..qsc_fir::fir::ExecGraphIdx::ZERO, + }, + ); + expr_id + } + qsc_fir::ty::Ty::Arrow(_) => { + // Arrow-typed position — the args must be a callable value. + lower_value_to_expr(package, assigner, args, callable_types) + } + _ => { + // Non-callable position — lower value if possible, otherwise placeholder. + match args { + Value::Qubit(_) | Value::Var(_) => { + make_placeholder_expr(package, assigner, input_ty) + } + _ => lower_value_to_expr(package, assigner, args, callable_types), + } + } + } + } + + /// Replaces UDT types with their pure structural FIR type, recursively. 
+ /// + /// Synthetic call construction operates on the post-erasure shape so callable + /// fields hidden inside UDTs can be discovered by defunctionalization. + fn resolve_udt_ty( + fir_store: &qsc_fir::fir::PackageStore, + ty: &qsc_fir::ty::Ty, + ) -> qsc_fir::ty::Ty { + match ty { + qsc_fir::ty::Ty::Udt(qsc_fir::fir::Res::Item(item_id)) => { + let package = fir_store.get(item_id.package); + let item = package + .items + .get(item_id.item) + .expect("UDT item should exist"); + let qsc_fir::fir::ItemKind::Ty(_, udt) = &item.kind else { + return ty.clone(); + }; + resolve_udt_ty(fir_store, &udt.get_pure_ty()) + } + qsc_fir::ty::Ty::Tuple(elems) => qsc_fir::ty::Ty::Tuple( + elems + .iter() + .map(|elem| resolve_udt_ty(fir_store, elem)) + .collect(), + ), + qsc_fir::ty::Ty::Array(elem) => { + qsc_fir::ty::Ty::Array(Box::new(resolve_udt_ty(fir_store, elem))) + } + qsc_fir::ty::Ty::Arrow(arrow) => qsc_fir::ty::Ty::Arrow(Box::new(qsc_fir::ty::Arrow { + kind: arrow.kind, + input: Box::new(resolve_udt_ty(fir_store, &arrow.input)), + output: Box::new(resolve_udt_ty(fir_store, &arrow.output)), + functors: arrow.functors, + })), + _ => ty.clone(), + } + } + + /// Returns true if the type is an Arrow or contains an Arrow in tuple structure. + fn ty_is_arrow_or_contains_arrow(ty: &qsc_fir::ty::Ty) -> bool { + match ty { + qsc_fir::ty::Ty::Arrow(_) => true, + qsc_fir::ty::Ty::Tuple(elems) => elems.iter().any(ty_is_arrow_or_contains_arrow), + _ => false, + } + } + + /// Creates a typed placeholder expression for a non-callable input position. + /// + /// Uses `Lit(Int(0))` with the declared type. The placeholder is never evaluated — + /// it exists only to make the synthetic Call structurally valid for pipeline passes. 
+ fn make_placeholder_expr( + package: &mut qsc_fir::fir::Package, + assigner: &mut qsc_fir::assigner::Assigner, + ty: &qsc_fir::ty::Ty, + ) -> qsc_fir::fir::ExprId { + let expr_id = assigner.next_expr(); + package.exprs.insert( + expr_id, + qsc_fir::fir::Expr { + id: expr_id, + span: qsc_data_structures::span::Span::default(), + ty: ty.clone(), + kind: qsc_fir::fir::ExprKind::Lit(qsc_fir::fir::Lit::Int(0)), + exec_graph_range: qsc_fir::fir::ExecGraphIdx::ZERO + ..qsc_fir::fir::ExecGraphIdx::ZERO, + }, + ); + expr_id + } + + /// Resolves `FunctorSet::Param` to `FunctorSet::Value(Empty)` recursively in a type. + /// + /// The lowerer may produce parametric functor sets for arrow-typed inputs. The synthetic + /// Call uses concrete types to satisfy post-mono invariants without requiring actual + /// monomorphization specialization of the pinned target. + fn resolve_functor_params(ty: &qsc_fir::ty::Ty) -> qsc_fir::ty::Ty { + match ty { + qsc_fir::ty::Ty::Arrow(arrow) => { + let functors = match arrow.functors { + qsc_fir::ty::FunctorSet::Param(_) | qsc_fir::ty::FunctorSet::Infer(_) => { + qsc_fir::ty::FunctorSet::Value(qsc_fir::ty::FunctorSetValue::Empty) + } + other @ qsc_fir::ty::FunctorSet::Value(_) => other, + }; + qsc_fir::ty::Ty::Arrow(Box::new(qsc_fir::ty::Arrow { + kind: arrow.kind, + input: Box::new(resolve_functor_params(&arrow.input)), + output: Box::new(resolve_functor_params(&arrow.output)), + functors, + })) + } + qsc_fir::ty::Ty::Tuple(elems) => { + qsc_fir::ty::Ty::Tuple(elems.iter().map(resolve_functor_params).collect()) + } + qsc_fir::ty::Ty::Array(inner) => { + qsc_fir::ty::Ty::Array(Box::new(resolve_functor_params(inner))) + } + other => other.clone(), + } + } + + /// Builds concrete generic args from a callable's generic parameter list. + /// + /// For each `TypeParameter::Functor`, produces `GenericArg::Functor(Value(Empty))`. + /// For each `TypeParameter::Ty`, produces `GenericArg::Ty(Tuple([]))` (unit). 
+ /// These concrete args let monomorphization create a fully resolved specialization. + fn build_concrete_generic_args( + generics: &[qsc_fir::ty::TypeParameter], + ) -> Vec { + generics + .iter() + .map(|param| match param { + qsc_fir::ty::TypeParameter::Functor(_) => qsc_fir::ty::GenericArg::Functor( + qsc_fir::ty::FunctorSet::Value(qsc_fir::ty::FunctorSetValue::Empty), + ), + qsc_fir::ty::TypeParameter::Ty { .. } => { + qsc_fir::ty::GenericArg::Ty(qsc_fir::ty::Ty::Tuple(Vec::new())) + } + }) + .collect() + } + + /// Extracts the specialized target callable from the entry Call expression after pipeline. + /// + /// After defunctionalization, the entry Call's callee Var references the specialized + /// (post-defunc) version of the target callable. This function extracts that ID. + #[allow(dead_code)] + fn extract_target_from_entry_call( + fir_store: &qsc_fir::fir::PackageStore, + fir_package_id: qsc_fir::fir::PackageId, + ) -> qsc_fir::fir::StoreItemId { + let package = fir_store.get(fir_package_id); + let entry_id = package + .entry + .expect("package must have entry after pipeline"); + let entry_expr = package.exprs.get(entry_id).expect("entry expr must exist"); + + let qsc_fir::fir::ExprKind::Call(callee_id, _) = &entry_expr.kind else { + panic!( + "entry expression must be a Call after pipeline, found {:?}", + entry_expr.kind + ); + }; + + let callee_expr = package + .exprs + .get(*callee_id) + .expect("callee expr must exist"); + let qsc_fir::fir::ExprKind::Var(qsc_fir::fir::Res::Item(item_id), _) = &callee_expr.kind + else { + panic!( + "entry Call callee must be a Var(Res::Item(...)) after pipeline, found {:?}", + callee_expr.kind + ); + }; + + qsc_fir::fir::StoreItemId { + package: item_id.package, + item: item_id.item, + } + } + + /// Lowers an interpreter `Value` into a FIR expression for the synthetic entry. 
+ /// + /// Scalar values become literals, aggregate values are lowered recursively, and + /// callable values are represented by global or closure variables with their + /// runtime functor application preserved. + #[allow(clippy::too_many_lines)] + fn lower_value_to_expr( + package: &mut qsc_fir::fir::Package, + assigner: &mut qsc_fir::assigner::Assigner, + value: &Value, + callable_types: &rustc_hash::FxHashMap, + ) -> qsc_fir::fir::ExprId { + let (kind, ty) = match value { + Value::Int(n) => ( + qsc_fir::fir::ExprKind::Lit(qsc_fir::fir::Lit::Int(*n)), + qsc_fir::ty::Ty::Prim(qsc_fir::ty::Prim::Int), + ), + Value::Double(d) => ( + qsc_fir::fir::ExprKind::Lit(qsc_fir::fir::Lit::Double(*d)), + qsc_fir::ty::Ty::Prim(qsc_fir::ty::Prim::Double), + ), + Value::Bool(b) => ( + qsc_fir::fir::ExprKind::Lit(qsc_fir::fir::Lit::Bool(*b)), + qsc_fir::ty::Ty::Prim(qsc_fir::ty::Prim::Bool), + ), + Value::BigInt(b) => ( + qsc_fir::fir::ExprKind::Lit(qsc_fir::fir::Lit::BigInt(b.clone())), + qsc_fir::ty::Ty::Prim(qsc_fir::ty::Prim::BigInt), + ), + Value::Pauli(p) => ( + qsc_fir::fir::ExprKind::Lit(qsc_fir::fir::Lit::Pauli(*p)), + qsc_fir::ty::Ty::Prim(qsc_fir::ty::Prim::Pauli), + ), + Value::Result(qsc_eval::val::Result::Val(b)) => ( + qsc_fir::fir::ExprKind::Lit(qsc_fir::fir::Lit::Result(if *b { + qsc_fir::fir::Result::One + } else { + qsc_fir::fir::Result::Zero + })), + qsc_fir::ty::Ty::Prim(qsc_fir::ty::Prim::Result), + ), + Value::String(s) => ( + qsc_fir::fir::ExprKind::String(vec![qsc_fir::fir::StringComponent::Lit(s.clone())]), + qsc_fir::ty::Ty::Prim(qsc_fir::ty::Prim::String), + ), + Value::Tuple(vs, _) => { + let mut lowered_ids = Vec::with_capacity(vs.len()); + let mut lowered_tys = Vec::with_capacity(vs.len()); + for v in vs.iter() { + let id = lower_value_to_expr(package, assigner, v, callable_types); + lowered_tys.push(package.exprs.get(id).expect("just inserted").ty.clone()); + lowered_ids.push(id); + } + ( + qsc_fir::fir::ExprKind::Tuple(lowered_ids), + 
qsc_fir::ty::Ty::Tuple(lowered_tys), + ) + } + Value::Array(vs) => { + let mut lowered_ids = Vec::with_capacity(vs.len()); + for v in vs.iter() { + lowered_ids.push(lower_value_to_expr(package, assigner, v, callable_types)); + } + let elem_ty = lowered_ids.first().map_or(qsc_fir::ty::Ty::Err, |id| { + package.exprs.get(*id).expect("just inserted").ty.clone() + }); + ( + qsc_fir::fir::ExprKind::Array(lowered_ids), + qsc_fir::ty::Ty::Array(Box::new(elem_ty)), + ) + } + Value::Range(r) => { + let lower_opt = |opt: Option, + pkg: &mut qsc_fir::fir::Package, + a: &mut qsc_fir::assigner::Assigner| + -> Option { + opt.map(|n| { + let id = a.next_expr(); + pkg.exprs.insert( + id, + qsc_fir::fir::Expr { + id, + span: qsc_data_structures::span::Span::default(), + ty: qsc_fir::ty::Ty::Prim(qsc_fir::ty::Prim::Int), + kind: qsc_fir::fir::ExprKind::Lit(qsc_fir::fir::Lit::Int(n)), + exec_graph_range: qsc_fir::fir::ExecGraphIdx::ZERO + ..qsc_fir::fir::ExecGraphIdx::ZERO, + }, + ); + id + }) + }; + let start = lower_opt(r.start, package, assigner); + let step = lower_opt(Some(r.step), package, assigner); + let end = lower_opt(r.end, package, assigner); + ( + qsc_fir::fir::ExprKind::Range(start, step, end), + qsc_fir::ty::Ty::Prim(qsc_fir::ty::Prim::Range), + ) + } + Value::Global(id, functor) => { + return lower_global_to_expr(package, assigner, *id, *functor, callable_types); + } + Value::Closure(c) => { + return lower_closure_to_expr(package, assigner, c, callable_types); + } + _ => panic!("cannot lower {value:?} to FIR expression"), + }; + + let expr_id = assigner.next_expr(); + package.exprs.insert( + expr_id, + qsc_fir::fir::Expr { + id: expr_id, + span: qsc_data_structures::span::Span::default(), + ty, + kind, + exec_graph_range: qsc_fir::fir::ExecGraphIdx::ZERO + ..qsc_fir::fir::ExecGraphIdx::ZERO, + }, + ); + expr_id + } + + /// Lowers a global callable value to a FIR variable expression. 
+ /// + /// The callable's stored `FunctorApp` is applied as FIR functor wrappers so + /// adjoint and controlled runtime values survive the synthetic entry path. + fn lower_global_to_expr( + package: &mut qsc_fir::fir::Package, + assigner: &mut qsc_fir::assigner::Assigner, + id: qsc_fir::fir::StoreItemId, + functor: FunctorApp, + callable_types: &rustc_hash::FxHashMap, + ) -> qsc_fir::fir::ExprId { + let ty = callable_types + .get(&id) + .expect("Global callable type must be pre-computed") + .clone(); + let expr_id = assigner.next_expr(); + package.exprs.insert( + expr_id, + qsc_fir::fir::Expr { + id: expr_id, + span: qsc_data_structures::span::Span::default(), + ty: ty.clone(), + kind: qsc_fir::fir::ExprKind::Var( + qsc_fir::fir::Res::Item(qsc_fir::fir::ItemId { + package: id.package, + item: id.item, + }), + Vec::new(), + ), + exec_graph_range: qsc_fir::fir::ExecGraphIdx::ZERO + ..qsc_fir::fir::ExecGraphIdx::ZERO, + }, + ); + wrap_expr_with_functor_app(package, assigner, expr_id, &ty, functor) + } + + /// Wraps a callable expression with the FIR functor operations in `functor`. + /// + /// Adjoint is applied before each controlled application to match the runtime + /// `FunctorApp` representation used by interpreter values. + fn wrap_expr_with_functor_app( + package: &mut qsc_fir::fir::Package, + assigner: &mut qsc_fir::assigner::Assigner, + expr_id: qsc_fir::fir::ExprId, + ty: &qsc_fir::ty::Ty, + functor: FunctorApp, + ) -> qsc_fir::fir::ExprId { + let mut current_id = expr_id; + if functor.adjoint { + current_id = wrap_expr_with_functor( + package, + assigner, + current_id, + ty, + qsc_fir::fir::Functor::Adj, + ); + } + for _ in 0..functor.controlled { + current_id = wrap_expr_with_functor( + package, + assigner, + current_id, + ty, + qsc_fir::fir::Functor::Ctl, + ); + } + current_id + } + + /// Creates a FIR unary functor expression around an existing callable expression. 
+ fn wrap_expr_with_functor( + package: &mut qsc_fir::fir::Package, + assigner: &mut qsc_fir::assigner::Assigner, + inner_id: qsc_fir::fir::ExprId, + ty: &qsc_fir::ty::Ty, + functor: qsc_fir::fir::Functor, + ) -> qsc_fir::fir::ExprId { + let expr_id = assigner.next_expr(); + package.exprs.insert( + expr_id, + qsc_fir::fir::Expr { + id: expr_id, + span: qsc_data_structures::span::Span::default(), + ty: ty.clone(), + kind: qsc_fir::fir::ExprKind::UnOp(qsc_fir::fir::UnOp::Functor(functor), inner_id), + exec_graph_range: qsc_fir::fir::ExecGraphIdx::ZERO + ..qsc_fir::fir::ExecGraphIdx::ZERO, + }, + ); + expr_id + } + + /// Lowers a captureless closure to its underlying callable variable expression. + /// + /// Capturing closures take the pinned fallback path before this is called, so + /// this helper only has to preserve the closure target and runtime functor app. + fn lower_closure_to_expr( + package: &mut qsc_fir::fir::Package, + assigner: &mut qsc_fir::assigner::Assigner, + closure: &qsc_eval::val::Closure, + callable_types: &rustc_hash::FxHashMap, + ) -> qsc_fir::fir::ExprId { + // For the synthetic entry, we emit a Var referencing the closure's underlying + // callable. Captures are irrelevant for pipeline reachability — defunc handles + // specialization. Both captureless and capturing closures use the same Var form. 
+ let ty = callable_types + .get(&closure.id) + .expect("Closure callable type must be pre-computed") + .clone(); + let kind = qsc_fir::fir::ExprKind::Var( + qsc_fir::fir::Res::Item(qsc_fir::fir::ItemId { + package: closure.id.package, + item: closure.id.item, + }), + Vec::new(), + ); + + let expr_id = assigner.next_expr(); + package.exprs.insert( + expr_id, + qsc_fir::fir::Expr { + id: expr_id, + span: qsc_data_structures::span::Span::default(), + ty: ty.clone(), + kind, + exec_graph_range: qsc_fir::fir::ExecGraphIdx::ZERO + ..qsc_fir::fir::ExecGraphIdx::ZERO, + }, + ); + wrap_expr_with_functor_app(package, assigner, expr_id, &ty, closure.functor) + } + + fn collect_concrete_qsharp_callables( + value: &Value, + callables: &mut FxHashSet, + ) { + match value { + Value::Array(values) => values + .iter() + .for_each(|value| collect_concrete_qsharp_callables(value, callables)), + Value::Closure(closure) => { + if !callables.contains(&closure.id) { + callables.insert(closure.id); + } + closure + .fixed_args + .iter() + .for_each(|value| collect_concrete_qsharp_callables(value, callables)); + } + Value::Global(store_item_id, _) => { + if !callables.contains(store_item_id) { + callables.insert(*store_item_id); + } + } + Value::Tuple(values, _) => values + .iter() + .for_each(|value| collect_concrete_qsharp_callables(value, callables)), + Value::BigInt(_) + | Value::Bool(_) + | Value::Double(_) + | Value::Int(_) + | Value::Pauli(_) + | Value::Qubit(_) + | Value::Range(_) + | Value::Result(_) + | Value::String(_) + | Value::Var(_) => {} + } + } + + /// Prepares codegen FIR when a callable is invoked with concrete argument values. + /// + /// Uses a synthetic `Call(Var(target), args)` entry expression when callable args + /// can be represented as FIR values, making the target and args entry-reachable for full + /// pipeline participation. 
Falls back to a pin-based approach when: + /// - Args contain closures with captures (partial applications require capture context + /// that can't be represented in the synthetic Call) + /// + /// The original target is pinned for DCE survival so that `fir_to_qir_from_callable` + /// can still use the original ID for partial evaluation. + pub fn prepare_codegen_fir_from_callable_args( + package_store: &PackageStore, + callable: qsc_hir::hir::ItemId, + args: &Value, + capabilities: TargetCapabilityFlags, + ) -> Result> { + let mut concrete_callables = FxHashSet::default(); + collect_concrete_qsharp_callables(args, &mut concrete_callables); + + if concrete_callables.is_empty() { + return prepare_codegen_fir_from_callable(package_store, callable, capabilities); + } + + // Closures with captures represent partial applications whose capture context + // can't be lowered into a synthetic Call expression yet. They still use the + // pin-based approach where partial eval handles specialization at QIR generation time. + if has_closure_with_captures(args) { + return prepare_codegen_fir_from_callable_args_pinned( + package_store, + callable, + args, + capabilities, + concrete_callables, + ); + } + + let (mut fir_store, fir_package_id, _assigner) = + lower_to_fir(package_store, callable.package, None); + + let target_callable = qsc_fir::fir::StoreItemId { + package: qsc_lowerer::map_hir_package_to_fir(callable.package), + item: qsc_lowerer::map_hir_local_item_to_fir(callable.item), + }; + + // Pre-compute callable type map (immutable store access) before mutating. + let callable_types = build_callable_type_map(&fir_store, &concrete_callables); + + // Build synthetic Call(Var(target), args) as the entry expression. + // This makes the target and all callable args entry-reachable for pipeline transforms. + seed_entry_with_call_to_target( + &mut fir_store, + fir_package_id, + target_callable, + args, + &callable_types, + ); + + // Pin the original target for DCE survival. 
After defunc rewrites the entry + // Call callee to reference the specialized version, the original target becomes + // unreachable. Pinning keeps it alive for `fir_to_qir_from_callable` which + // uses the original ID with original-shaped args. + run_codegen_pipeline_to( + package_store, + callable.package, + &mut fir_store, + fir_package_id, + qsc_fir_transforms::PipelineStage::Full, + &[target_callable], + )?; + let compute_properties = qsc_rca::Analyzer::init(&fir_store, capabilities).analyze_all(); + validate_callable_capabilities( + package_store, + &fir_store, + &compute_properties, + target_callable, + capabilities, + )?; + + Ok(CodegenFir { + fir_store, + fir_package_id, + compute_properties, + }) + } + + /// Pin-based fallback for callable args containing closures with captures. + /// + /// Seeds concrete (non-arrow-input) callables into the entry for reachability, + /// pins arrow-input callables and the target for DCE survival, and lets + /// `fir_to_qir_from_callable` handle specialization at QIR generation time. 
+ fn prepare_codegen_fir_from_callable_args_pinned( + package_store: &PackageStore, + callable: qsc_hir::hir::ItemId, + _args: &Value, + capabilities: TargetCapabilityFlags, + mut concrete_callables: FxHashSet, + ) -> Result> { + let (mut fir_store, fir_package_id, _assigner) = + lower_to_fir(package_store, callable.package, None); + + let mut pinned_callables: Vec = Vec::new(); + concrete_callables.retain(|store_item_id| { + let hir_item_id = qsc_hir::hir::ItemId { + package: qsc_lowerer::map_fir_package_to_hir(store_item_id.package), + item: qsc_lowerer::map_fir_local_item_to_hir(store_item_id.item), + }; + if callable_has_arrow_input(&fir_store, hir_item_id) { + pinned_callables.push(*store_item_id); + false + } else { + true + } + }); + + let target_callable = qsc_fir::fir::StoreItemId { + package: qsc_lowerer::map_hir_package_to_fir(callable.package), + item: qsc_lowerer::map_hir_local_item_to_fir(callable.item), + }; + + seed_entry_with_callables(&mut fir_store, fir_package_id, &concrete_callables); + pinned_callables.push(target_callable); + run_codegen_pipeline_to( + package_store, + callable.package, + &mut fir_store, + fir_package_id, + qsc_fir_transforms::PipelineStage::Full, + &pinned_callables, + )?; + let compute_properties = qsc_rca::Analyzer::init(&fir_store, capabilities).analyze_all(); + validate_callable_capabilities( + package_store, + &fir_store, + &compute_properties, + target_callable, + capabilities, + )?; + + Ok(CodegenFir { + fir_store, + fir_package_id, + compute_properties, + }) + } + + /// Returns `true` if the value tree contains any closures with captures. 
+ fn has_closure_with_captures(value: &Value) -> bool { + match value { + Value::Closure(c) => !c.fixed_args.is_empty(), + Value::Tuple(vs, _) => vs.iter().any(has_closure_with_captures), + Value::Array(vs) => vs.iter().any(has_closure_with_captures), + _ => false, + } + } + + fn prepare_codegen_fir_inner( + package_store: &PackageStore, + package_id: qsc_hir::hir::PackageId, + package_override: Option<&qsc_hir::hir::Package>, + capabilities: TargetCapabilityFlags, + ) -> Result> { + let (fir_store, fir_package_id, _) = + lower_to_fir(package_store, package_id, package_override); + + prepare_codegen_fir_from_lowered_store( + package_store, + package_id, + fir_store, + fir_package_id, + capabilities, + ) + } + + fn prepare_codegen_fir_from_lowered_store( + package_store: &PackageStore, + package_id: qsc_hir::hir::PackageId, + mut fir_store: qsc_fir::fir::PackageStore, + fir_package_id: qsc_fir::fir::PackageId, + capabilities: TargetCapabilityFlags, + ) -> Result> { + run_codegen_pipeline(package_store, package_id, &mut fir_store, fir_package_id)?; + + let compute_properties = + PassContext::run_fir_passes_on_fir(&fir_store, fir_package_id, capabilities) + .map_err(|errors| map_pass_errors(package_store, package_id, errors))?; + + Ok(CodegenFir { + fir_store, + fir_package_id, + compute_properties, + }) + } + + pub fn prepare_codegen_fir( + package_store: &PackageStore, + package_id: qsc_hir::hir::PackageId, + capabilities: TargetCapabilityFlags, + ) -> Result> { + prepare_codegen_fir_inner(package_store, package_id, None, capabilities) + } + + pub fn prepare_codegen_fir_from_fir_store( + package_store: &PackageStore, + package_id: qsc_hir::hir::PackageId, + fir_store: &qsc_fir::fir::PackageStore, + fir_package_id: qsc_fir::fir::PackageId, + capabilities: TargetCapabilityFlags, + ) -> Result> { + prepare_codegen_fir_from_lowered_store( + package_store, + package_id, + clone_fir_store(fir_store), + fir_package_id, + capabilities, + ) + } + + /// Prepares codegen FIR 
for a single callable without inline arguments. + /// + /// Used when a callable is referenced but its concrete argument values are not yet known. + /// For callables with arrow-typed inputs, skips the full pipeline to preserve abstract + /// higher-order structure that will be specialized later via `prepare_codegen_fir_from_callable_args`. + pub fn prepare_codegen_fir_from_callable( + package_store: &PackageStore, + callable: qsc_hir::hir::ItemId, + capabilities: TargetCapabilityFlags, + ) -> Result> { + let (mut fir_store, fir_package_id, mut assigner) = + lower_to_fir(package_store, callable.package, None); + + if callable_has_arrow_input(&fir_store, callable) { + // Callable-based codegen receives the concrete callable arguments later through + // partially_evaluate_call. Running the FIR transform pipeline from a bare callable + // reference loses that higher-order call-site information and can leave functor- + // parameterized arrow types unspecialized. + return Ok(CodegenFir { + compute_properties: qsc_rca::Analyzer::init(&fir_store, capabilities).analyze_all(), + fir_store, + fir_package_id, + }); + } + + seed_entry_with_callable(&mut fir_store, fir_package_id, callable, &mut assigner); + run_codegen_pipeline( + package_store, + callable.package, + &mut fir_store, + fir_package_id, + )?; + + let compute_properties = qsc_rca::Analyzer::init(&fir_store, capabilities).analyze_all(); + validate_callable_capabilities( + package_store, + &fir_store, + &compute_properties, + qsc_fir::fir::StoreItemId { + package: qsc_lowerer::map_hir_package_to_fir(callable.package), + item: qsc_lowerer::map_hir_local_item_to_fir(callable.item), + }, + capabilities, + )?; + + Ok(CodegenFir { + fir_store, + fir_package_id, + compute_properties, + }) + } + + fn compile_to_codegen_fir( + sources: SourceMap, + language_features: LanguageFeatures, + capabilities: TargetCapabilityFlags, + package_store: &mut PackageStore, + dependencies: &Dependencies, + ) -> 
Result<(qsc_hir::hir::PackageId, CodegenFir), Vec> { + if capabilities == TargetCapabilityFlags::all() { + return Err(vec![Error::UnsupportedRuntimeCapabilities]); + } + + let (unit, errors) = crate::compile::compile( + package_store, + dependencies, + sources, + PackageType::Exe, + capabilities, + language_features, + ); + if !errors.is_empty() { + return Err(errors.iter().map(|e| Error::Compile(e.clone())).collect()); + } + + let package_id = package_store.insert(unit); + let prepared_fir = prepare_codegen_fir(package_store, package_id, capabilities)?; + Ok((package_id, prepared_fir)) + } + pub fn get_qir_from_ast( store: &mut PackageStore, dependencies: &Dependencies, @@ -47,33 +1415,15 @@ pub mod qir { } let package_id = store.insert(unit); - let (fir_store, fir_package_id) = qsc_passes::lower_hir_to_fir(store, package_id); - let package = fir_store.get(fir_package_id); - let entry = ProgramEntry { - exec_graph: package.entry_exec_graph.clone(), - expr: ( - fir_package_id, - package - .entry - .expect("package must have an entry expression"), - ) - .into(), - }; - - let compute_properties = PassContext::run_fir_passes_on_fir( - &fir_store, - fir_package_id, - capabilities, - ) - .map_err(|errors| { - let source_package = store.get(package_id).expect("package should be in store"); - errors - .iter() - .map(|e| Error::Pass(WithSource::from_map(&source_package.sources, e.clone()))) - .collect::>() - })?; + let prepared_fir = prepare_codegen_fir(store, package_id, capabilities)?; + let entry = entry_from_codegen_fir(&prepared_fir); + let CodegenFir { + fir_store, + compute_properties, + .. 
+ } = prepared_fir; - fir_to_qir(&fir_store, capabilities, Some(compute_properties), &entry).map_err(|e| { + fir_to_qir(&fir_store, capabilities, &compute_properties, &entry).map_err(|e| { let source_package_id = match e.span() { Some(span) => span.package, None => package_id, @@ -95,18 +1445,24 @@ pub mod qir { mut package_store: PackageStore, dependencies: &Dependencies, ) -> Result, Vec> { - let (package_id, fir_store, entry, compute_properties) = compile_to_fir( + let (package_id, prepared_fir) = compile_to_codegen_fir( sources, language_features, capabilities, &mut package_store, dependencies, )?; + let entry = entry_from_codegen_fir(&prepared_fir); + let CodegenFir { + fir_store, + compute_properties, + .. + } = prepared_fir; let (raw, ssa) = fir_to_rir( &fir_store, capabilities, - Some(compute_properties), + &compute_properties, &entry, PartialEvalConfig { generate_debug_metadata: true, @@ -135,15 +1491,21 @@ pub mod qir { mut package_store: PackageStore, dependencies: &Dependencies, ) -> Result> { - let (package_id, fir_store, entry, compute_properties) = compile_to_fir( + let (package_id, prepared_fir) = compile_to_codegen_fir( sources, language_features, capabilities, &mut package_store, dependencies, )?; + let entry = entry_from_codegen_fir(&prepared_fir); + let CodegenFir { + fir_store, + compute_properties, + .. 
+ } = prepared_fir; - fir_to_qir(&fir_store, capabilities, Some(compute_properties), &entry).map_err(|e| { + fir_to_qir(&fir_store, capabilities, &compute_properties, &entry).map_err(|e| { let source_package_id = match e.span() { Some(span) => span.package, None => package_id, @@ -157,63 +1519,4 @@ pub mod qir { ))] }) } - - fn compile_to_fir( - sources: SourceMap, - language_features: LanguageFeatures, - capabilities: TargetCapabilityFlags, - package_store: &mut PackageStore, - dependencies: &[(qsc_hir::hir::PackageId, Option>)], - ) -> Result< - ( - qsc_hir::hir::PackageId, - qsc_fir::fir::PackageStore, - ProgramEntry, - qsc_rca::PackageStoreComputeProperties, - ), - Vec, - > { - if capabilities == TargetCapabilityFlags::all() { - return Err(vec![Error::UnsupportedRuntimeCapabilities]); - } - let (unit, errors) = crate::compile::compile( - package_store, - dependencies, - sources, - PackageType::Exe, - capabilities, - language_features, - ); - if !errors.is_empty() { - return Err(errors.iter().map(|e| Error::Compile(e.clone())).collect()); - } - let package_id = package_store.insert(unit); - let (fir_store, fir_package_id) = qsc_passes::lower_hir_to_fir(package_store, package_id); - let package = fir_store.get(fir_package_id); - let entry = ProgramEntry { - exec_graph: package.entry_exec_graph.clone(), - expr: ( - fir_package_id, - package - .entry - .expect("package must have an entry expression"), - ) - .into(), - }; - let compute_properties = PassContext::run_fir_passes_on_fir( - &fir_store, - fir_package_id, - capabilities, - ) - .map_err(|errors| { - let source_package = package_store - .get(package_id) - .expect("package should be in store"); - errors - .iter() - .map(|e| Error::Pass(WithSource::from_map(&source_package.sources, e.clone()))) - .collect::>() - })?; - Ok((package_id, fir_store, entry, compute_properties)) - } } diff --git a/source/compiler/qsc/src/codegen/tests.rs b/source/compiler/qsc/src/codegen/tests.rs index 0c01f0ac62..53b12cb8a5 100644 
--- a/source/compiler/qsc/src/codegen/tests.rs +++ b/source/compiler/qsc/src/codegen/tests.rs @@ -1,15 +1,72 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +#![allow(clippy::too_many_lines)] + +use std::sync::Arc; + +use std::rc::Rc; + use expect_test::expect; +use miette::Report; use qsc_data_structures::{ - language_features::LanguageFeatures, source::SourceMap, target::TargetCapabilityFlags, + functors::FunctorApp, language_features::LanguageFeatures, source::SourceMap, + target::TargetCapabilityFlags, +}; +use qsc_eval::val::Value; +use qsc_frontend::compile::parse_all; +use qsc_hir::hir::{ItemKind, PackageId}; +use rustc_hash::FxHashMap; + +use crate::codegen::qir::{ + get_qir, get_qir_from_ast, get_rir, prepare_codegen_fir_from_callable_args, }; -use crate::codegen::qir::get_qir; +fn format_interpret_errors(errors: Vec) -> String { + errors + .into_iter() + .map(|error| format!("{:?}", Report::new(error))) + .collect::>() + .join("\n\n") +} + +fn source_map_from_source(source: &str) -> SourceMap { + SourceMap::new([("test.qs".into(), source.into())], None) +} + +fn parse_source_to_ast(source: &str) -> (qsc_ast::ast::Package, SourceMap) { + let sources = source_map_from_source(source); + let language_features = LanguageFeatures::default(); + let (ast_package, errors) = parse_all(&sources, language_features); + + if errors.is_empty() { + (ast_package, sources) + } else { + let diagnostics = errors + .into_iter() + .map(|error| format!("{:?}", Report::new(error))) + .collect::>() + .join("\n\n"); + + panic!("Failed to parse AST test source:\n{diagnostics}"); + } +} fn compile_source_to_qir(source: &str, capabilities: TargetCapabilityFlags) -> String { - let sources = SourceMap::new([("test.qs".into(), source.into())], None); + match compile_source_to_qir_result(source, capabilities) { + Ok(qir) => qir, + Err(errors) => panic!( + "Failed to generate QIR for capabilities {capabilities:?}:\n{}", + format_interpret_errors(errors) + ), 
+ } +} + +fn compile_source_to_qir_result( + source: &str, + capabilities: TargetCapabilityFlags, +) -> Result> { + let sources = source_map_from_source(source); let language_features = LanguageFeatures::default(); let (std_id, store) = crate::compile::package_store_with_stdlib(capabilities); @@ -20,7 +77,61 @@ fn compile_source_to_qir(source: &str, capabilities: TargetCapabilityFlags) -> S store, &[(std_id, None)], ) - .expect("Failed to generate QIR") +} + +fn compile_source_to_qir_from_ast(source: &str, capabilities: TargetCapabilityFlags) -> String { + match compile_source_to_qir_from_ast_result(source, capabilities) { + Ok(qir) => qir, + Err(errors) => panic!( + "Failed to generate QIR from AST for capabilities {capabilities:?}:\n{}", + format_interpret_errors(errors) + ), + } +} + +fn compile_source_to_qir_from_ast_result( + source: &str, + capabilities: TargetCapabilityFlags, +) -> Result> { + let (ast_package, sources) = parse_source_to_ast(source); + let (std_id, mut store) = crate::compile::package_store_with_stdlib(capabilities); + let dependencies: Vec<(PackageId, Option>)> = + vec![(PackageId::CORE, None), (std_id, None)]; + + get_qir_from_ast( + &mut store, + &dependencies, + ast_package, + sources, + capabilities, + ) +} + +fn compile_source_to_rir(source: &str, capabilities: TargetCapabilityFlags) -> Vec { + match compile_source_to_rir_result(source, capabilities) { + Ok(rir) => rir, + Err(errors) => panic!( + "Failed to generate RIR for capabilities {capabilities:?}:\n{}", + format_interpret_errors(errors) + ), + } +} + +fn compile_source_to_rir_result( + source: &str, + capabilities: TargetCapabilityFlags, +) -> Result, Vec> { + let sources = source_map_from_source(source); + let language_features = LanguageFeatures::default(); + + let (std_id, store) = crate::compile::package_store_with_stdlib(capabilities); + get_rir( + sources, + language_features, + capabilities, + store, + &[(std_id, None)], + ) } #[test] @@ -76,201 +187,2744 @@ fn 
code_with_errors_returns_errors() { } #[test] -fn code_returning_struct_from_entry_point_generates_errors() { - let source = "namespace Test { +fn unsupported_profile_patterns_return_pass_errors() { + let res = compile_source_to_qir_result( + indoc::indoc! {r#" + namespace Test { + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + mutable x = 1; + if MResetZ(q) == One { + set x = 2; + } + x + } + } + "#}, + TargetCapabilityFlags::Adaptive, + ); + + let errors = res.expect_err("expected capability error"); + assert!(!errors.is_empty(), "expected at least one error"); + assert!( + errors + .iter() + .all(|error| matches!(error, crate::interpret::Error::Pass(_))), + "expected pass-derived codegen readiness errors, got {errors:?}" + ); + assert!( + errors.iter().any(|error| error + .to_string() + .contains("cannot use a dynamic integer value")), + "expected a dynamic integer capability diagnostic, got {errors:?}" + ); +} + +#[test] +fn qir_generation_succeeds_for_struct_copy_update() { + let source = r#" + namespace Test { @EntryPoint() - operation Main() : Std.Math.Complex { - new Std.Math.Complex { Real = 0.0, Imag = 0.0 } + operation Main() : Unit { + struct Point3d { X : Double, Y : Double, Z : Double } + + let point = new Point3d { X = 1.0, Y = 2.0, Z = 3.0 }; + let point2 = new Point3d { ...point, Z = 4.0 }; + let x : Double = point2.X; } - }"; - let sources = SourceMap::new([("test.qs".into(), source.into())], None); - let language_features = LanguageFeatures::default(); - let capabilities = TargetCapabilityFlags::empty(); - let (std_id, store) = crate::compile::package_store_with_stdlib(capabilities); + } + "#; + let qir = compile_source_to_qir(source, TargetCapabilityFlags::empty()); expect![[r#" - Err( - [ - Pass( - WithSource { - sources: [ - Source { - name: "test.qs", - contents: "namespace Test {\n @EntryPoint()\n operation Main() : Std.Math.Complex {\n new Std.Math.Complex { Real = 0.0, Imag = 0.0 }\n }\n }", - offset: 0, - }, - ], - error: 
CapabilitiesCk( - UseOfAdvancedOutput( - Span { - lo: 65, - hi: 69, - }, - ), - ), - }, - ), - ], - ) + %Result = type opaque + %Qubit = type opaque + + @0 = internal constant [4 x i8] c"0_t\00" + + define i64 @ENTRYPOINT__main() #0 { + block_0: + call void @__quantum__rt__initialize(i8* null) + call void @__quantum__rt__tuple_record_output(i64 0, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + ret i64 0 + } + + declare void @__quantum__rt__initialize(i8*) + + declare void @__quantum__rt__tuple_record_output(i64, i8*) + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="1" "required_num_results"="0" } + attributes #1 = { "irreversible" } + + ; module flags + + !llvm.module.flags = !{!0, !1, !2, !3} + + !0 = !{i32 1, !"qir_major_version", i32 1} + !1 = !{i32 7, !"qir_minor_version", i32 0} + !2 = !{i32 1, !"dynamic_qubit_management", i1 false} + !3 = !{i32 1, !"dynamic_result_management", i1 false} "#]] - .assert_debug_eq(&get_qir(sources, language_features, capabilities, store, &[(std_id, None)])); + .assert_eq(&qir); } #[test] -fn code_returning_struct_from_entry_expr_generates_errors() { - let source = ""; - let entry = "new Std.Math.Complex { Real = 0.0, Imag = 0.0 }"; - let sources = SourceMap::new([("test.qs".into(), source.into())], Some(entry.into())); - let language_features = LanguageFeatures::default(); - let capabilities = TargetCapabilityFlags::empty(); - let (std_id, store) = crate::compile::package_store_with_stdlib(capabilities); +fn deutsch_jozsa_sample_shape_generates_qir() { + let source = indoc::indoc! 
{r#" + namespace Test { + import Std.Diagnostics.*; + import Std.Math.*; + import Std.Measurement.*; + + @EntryPoint() + operation Main() : Bool[] { + let functionsToTest = [ + SimpleConstantBoolF, + SimpleBalancedBoolF, + ConstantBoolF, + BalancedBoolF + ]; + + mutable results = []; + for fn in functionsToTest { + let isConstant = DeutschJozsa(fn, 3); + set results += [isConstant]; + } + + return results; + } + + operation DeutschJozsa(Uf : ((Qubit[], Qubit) => Unit), n : Int) : Bool { + use queryRegister = Qubit[n]; + use target = Qubit(); + X(target); + H(target); + within { + for q in queryRegister { + H(q); + } + } apply { + Uf(queryRegister, target); + } + + mutable result = true; + for q in queryRegister { + if MResetZ(q) == One { + set result = false; + } + } + + Reset(target); + return result; + } + + operation SimpleConstantBoolF(args : Qubit[], target : Qubit) : Unit { + X(target); + } + + operation SimpleBalancedBoolF(args : Qubit[], target : Qubit) : Unit { + CX(args[0], target); + } + + operation ConstantBoolF(args : Qubit[], target : Qubit) : Unit { + for i in 0..(2^Length(args)) - 1 { + ApplyControlledOnInt(i, X, args, target); + } + } + + operation BalancedBoolF(args : Qubit[], target : Qubit) : Unit { + for i in 0..2..(2^Length(args)) - 1 { + ApplyControlledOnInt(i, X, args, target); + } + } + } + "#}; + + let qir = compile_source_to_qir( + source, + TargetCapabilityFlags::Adaptive + | TargetCapabilityFlags::IntegerComputations + | TargetCapabilityFlags::FloatingPointComputations, + ); expect![[r#" - Err( - [ - Pass( - WithSource { - sources: [ - Source { - name: "", - contents: "new Std.Math.Complex { Real = 0.0, Imag = 0.0 }", - offset: 0, - }, - ], - error: CapabilitiesCk( - UseOfAdvancedOutput( - Span { - lo: 0, - hi: 47, - }, - ), - ), - }, - ), - ], - ) + %Result = type opaque + %Qubit = type opaque + + @0 = internal constant [4 x i8] c"0_a\00" + @1 = internal constant [6 x i8] c"1_a0b\00" + @2 = internal constant [6 x i8] c"2_a1b\00" + @3 = 
internal constant [6 x i8] c"3_a2b\00" + @4 = internal constant [6 x i8] c"4_a3b\00" + + define i64 @ENTRYPOINT__main() #0 { + block_0: + call void @__quantum__rt__initialize(i8* null) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + %var_6 = call i1 @__quantum__rt__read_result(%Result* inttoptr (i64 0 to %Result*)) + br i1 %var_6, label %block_1, label %block_2 + block_1: + br label %block_2 + block_2: + %var_139 = phi i1 [true, %block_0], [false, %block_1] + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) + %var_8 = call i1 @__quantum__rt__read_result(%Result* inttoptr (i64 1 to %Result*)) + br i1 %var_8, label %block_3, label %block_4 + block_3: + br label %block_4 + block_4: + %var_140 = phi i1 [%var_139, %block_2], [false, %block_3] + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Result* inttoptr (i64 2 to %Result*)) + %var_10 = call i1 @__quantum__rt__read_result(%Result* inttoptr (i64 2 to %Result*)) + br i1 %var_10, label %block_5, label %block_6 + block_5: + br label %block_6 + block_6: + %var_141 = phi i1 [%var_140, %block_4], [false, %block_5] + call void @__quantum__qis__reset__body(%Qubit* inttoptr (i64 3 to %Qubit*)) + call 
void @__quantum__qis__x__body(%Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__cx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 3 to %Result*)) + %var_19 = call i1 @__quantum__rt__read_result(%Result* inttoptr (i64 3 to %Result*)) + br i1 %var_19, label %block_7, label %block_8 + block_7: + br label %block_8 + block_8: + %var_142 = phi i1 [true, %block_6], [false, %block_7] + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 4 to %Result*)) + %var_21 = call i1 @__quantum__rt__read_result(%Result* inttoptr (i64 4 to %Result*)) + br i1 %var_21, label %block_9, label %block_10 + block_9: + br label %block_10 + block_10: + %var_143 = phi i1 [%var_142, %block_8], [false, %block_9] + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Result* inttoptr (i64 5 to %Result*)) + %var_23 = call i1 @__quantum__rt__read_result(%Result* inttoptr (i64 5 to %Result*)) + br i1 %var_23, label %block_11, label %block_12 + block_11: + br label %block_12 + block_12: + %var_144 = phi i1 [%var_143, %block_10], [false, %block_11] + call void @__quantum__qis__reset__body(%Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 3 to %Qubit*)) + 
call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*), %Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*), %Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 2 to 
%Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*), %Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*), %Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*), %Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* 
inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*), %Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*), %Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*), %Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 6 to %Result*)) + %var_89 = call i1 @__quantum__rt__read_result(%Result* 
inttoptr (i64 6 to %Result*)) + br i1 %var_89, label %block_13, label %block_14 + block_13: + br label %block_14 + block_14: + %var_145 = phi i1 [true, %block_12], [false, %block_13] + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 7 to %Result*)) + %var_91 = call i1 @__quantum__rt__read_result(%Result* inttoptr (i64 7 to %Result*)) + br i1 %var_91, label %block_15, label %block_16 + block_15: + br label %block_16 + block_16: + %var_146 = phi i1 [%var_145, %block_14], [false, %block_15] + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Result* inttoptr (i64 8 to %Result*)) + %var_93 = call i1 @__quantum__rt__read_result(%Result* inttoptr (i64 8 to %Result*)) + br i1 %var_93, label %block_17, label %block_18 + block_17: + br label %block_18 + block_18: + %var_147 = phi i1 [%var_146, %block_16], [false, %block_17] + call void @__quantum__qis__reset__body(%Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*), %Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* 
inttoptr (i64 4 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*), %Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*), %Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*)) + call void 
@__quantum__qis__ccx__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*), %Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__ccx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 4 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 9 to %Result*)) + %var_131 = call i1 @__quantum__rt__read_result(%Result* inttoptr (i64 9 to %Result*)) + br i1 %var_131, label %block_19, label %block_20 + block_19: + br label %block_20 + block_20: + %var_148 = phi i1 [true, %block_18], [false, %block_19] + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 10 to %Result*)) + %var_133 = call i1 @__quantum__rt__read_result(%Result* inttoptr (i64 10 to %Result*)) + br i1 %var_133, label %block_21, label %block_22 + block_21: + br label %block_22 + block_22: + %var_149 = phi i1 [%var_148, %block_20], [false, %block_21] + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Result* inttoptr (i64 11 to %Result*)) + %var_135 = call i1 @__quantum__rt__read_result(%Result* inttoptr (i64 11 to %Result*)) + br i1 %var_135, label %block_23, label %block_24 + block_23: + br label %block_24 + block_24: + %var_150 = phi i1 [%var_149, %block_22], [false, %block_23] + call void @__quantum__qis__reset__body(%Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__rt__array_record_output(i64 4, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + call void @__quantum__rt__bool_record_output(i1 %var_141, i8* getelementptr inbounds ([6 x i8], [6 x i8]* @1, i64 0, 
i64 0)) + call void @__quantum__rt__bool_record_output(i1 %var_144, i8* getelementptr inbounds ([6 x i8], [6 x i8]* @2, i64 0, i64 0)) + call void @__quantum__rt__bool_record_output(i1 %var_147, i8* getelementptr inbounds ([6 x i8], [6 x i8]* @3, i64 0, i64 0)) + call void @__quantum__rt__bool_record_output(i1 %var_150, i8* getelementptr inbounds ([6 x i8], [6 x i8]* @4, i64 0, i64 0)) + ret i64 0 + } + + declare void @__quantum__rt__initialize(i8*) + + declare void @__quantum__qis__x__body(%Qubit*) + + declare void @__quantum__qis__h__body(%Qubit*) + + declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 + + declare i1 @__quantum__rt__read_result(%Result*) + + declare void @__quantum__qis__reset__body(%Qubit*) #1 + + declare void @__quantum__qis__cx__body(%Qubit*, %Qubit*) + + declare void @__quantum__qis__ccx__body(%Qubit*, %Qubit*, %Qubit*) + + declare void @__quantum__rt__array_record_output(i64, i8*) + + declare void @__quantum__rt__bool_record_output(i1, i8*) + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="5" "required_num_results"="12" } + attributes #1 = { "irreversible" } + + ; module flags + + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5} + + !0 = !{i32 1, !"qir_major_version", i32 1} + !1 = !{i32 7, !"qir_minor_version", i32 0} + !2 = !{i32 1, !"dynamic_qubit_management", i1 false} + !3 = !{i32 1, !"dynamic_result_management", i1 false} + !4 = !{i32 5, !"int_computations", !{!"i64"}} + !5 = !{i32 5, !"float_computations", !{!"double"}} "#]] - .assert_debug_eq(&get_qir( - sources, - language_features, - capabilities, - store, - &[(std_id, None)], - )); + .assert_eq(&qir); } #[test] -fn code_returning_struct_from_block_entry_expr_generates_errors() { - let source = ""; - let entry = "{ new Std.Math.Complex { Real = 0.0, Imag = 0.0 } }"; - let sources = SourceMap::new([("test.qs".into(), source.into())], Some(entry.into())); - let language_features = 
LanguageFeatures::default(); - let capabilities = TargetCapabilityFlags::empty(); - let (std_id, store) = crate::compile::package_store_with_stdlib(capabilities); +fn simple_phase_estimation_sample_shape_generates_qir() { + let source = indoc::indoc! {r#" + namespace Test { + operation Main() : Result[] { + use state = Qubit(); + use phase = Qubit[3]; + + X(state); + + let oracle = ApplyOperationPowerCA(_, qs => U(qs[0]), _); + ApplyQPE(oracle, [state], phase); + + let results = MeasureEachZ(phase); + + Reset(state); + ResetAll(phase); + + Std.Arrays.Reversed(results) + } + + operation U(q : Qubit) : Unit is Ctl + Adj { + Rz(Std.Math.PI() / 3.0, q); + } + } + "#}; + + let qir = compile_source_to_qir( + source, + TargetCapabilityFlags::Adaptive + | TargetCapabilityFlags::IntegerComputations + | TargetCapabilityFlags::FloatingPointComputations, + ); expect![[r#" - Err( - [ - Pass( - WithSource { - sources: [ - Source { - name: "", - contents: "{ new Std.Math.Complex { Real = 0.0, Imag = 0.0 } }", - offset: 0, - }, - ], - error: CapabilitiesCk( - UseOfAdvancedOutput( - Span { - lo: 0, - hi: 51, - }, - ), - ), - }, - ), - ], - ) + %Result = type opaque + %Qubit = type opaque + + @0 = internal constant [4 x i8] c"0_a\00" + @1 = internal constant [6 x i8] c"1_a0r\00" + @2 = internal constant [6 x i8] c"2_a1r\00" + @3 = internal constant [6 x i8] c"3_a2r\00" + + define i64 @ENTRYPOINT__main() #0 { + block_0: + call void @__quantum__rt__initialize(i8* null) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__rz__body(double 0.5235987755982988, %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__cx__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 0 to %Qubit*)) + call void 
@__quantum__qis__rz__body(double -0.5235987755982988, %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__cx__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__rz__body(double 0.5235987755982988, %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__cx__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__rz__body(double -0.5235987755982988, %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__cx__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__rz__body(double 0.5235987755982988, %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__cx__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__rz__body(double -0.5235987755982988, %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__cx__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__rz__body(double 0.5235987755982988, %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__cx__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__rz__body(double -0.5235987755982988, %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__cx__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__rz__body(double 0.5235987755982988, %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__cx__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__rz__body(double -0.5235987755982988, %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__cx__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__rz__body(double 0.5235987755982988, %Qubit* 
inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__cx__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__rz__body(double -0.5235987755982988, %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__cx__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__rz__body(double 0.5235987755982988, %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__cx__body(%Qubit* inttoptr (i64 3 to %Qubit*), %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__rz__body(double -0.5235987755982988, %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__cx__body(%Qubit* inttoptr (i64 3 to %Qubit*), %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__rz__body(double -0.7853981633974483, %Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__rz__body(double -0.7853981633974483, %Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__cx__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__rz__body(double 0.7853981633974483, %Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__cx__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__rz__body(double -0.39269908169872414, %Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__rz__body(double -0.39269908169872414, %Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__cx__body(%Qubit* inttoptr (i64 3 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__rz__body(double 0.39269908169872414, %Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__cx__body(%Qubit* inttoptr (i64 3 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*)) 
+ call void @__quantum__qis__rz__body(double -0.7853981633974483, %Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__rz__body(double -0.7853981633974483, %Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__cx__body(%Qubit* inttoptr (i64 3 to %Qubit*), %Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__rz__body(double 0.7853981633974483, %Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__cx__body(%Qubit* inttoptr (i64 3 to %Qubit*), %Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 3 to %Qubit*), %Result* inttoptr (i64 2 to %Result*)) + call void @__quantum__qis__reset__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__reset__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__reset__body(%Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__reset__body(%Qubit* inttoptr (i64 3 to %Qubit*)) + call void @__quantum__rt__array_record_output(i64 3, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 2 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @1, i64 0, i64 0)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @2, i64 0, i64 0)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @3, i64 0, i64 0)) + ret i64 0 + } + + declare void @__quantum__rt__initialize(i8*) + + declare void @__quantum__qis__x__body(%Qubit*) + + declare void 
@__quantum__qis__h__body(%Qubit*) + + declare void @__quantum__qis__rz__body(double, %Qubit*) + + declare void @__quantum__qis__cx__body(%Qubit*, %Qubit*) + + declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 + + declare void @__quantum__qis__reset__body(%Qubit*) #1 + + declare void @__quantum__rt__array_record_output(i64, i8*) + + declare void @__quantum__rt__result_record_output(%Result*, i8*) + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="4" "required_num_results"="3" } + attributes #1 = { "irreversible" } + + ; module flags + + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5} + + !0 = !{i32 1, !"qir_major_version", i32 1} + !1 = !{i32 7, !"qir_minor_version", i32 0} + !2 = !{i32 1, !"dynamic_qubit_management", i1 false} + !3 = !{i32 1, !"dynamic_result_management", i1 false} + !4 = !{i32 5, !"int_computations", !{!"i64"}} + !5 = !{i32 5, !"float_computations", !{!"double"}} "#]] - .assert_debug_eq(&get_qir( - sources, - language_features, - capabilities, - store, - &[(std_id, None)], - )); + .assert_eq(&qir); } #[test] -fn code_returning_struct_from_if_entry_expr_generates_errors() { - let source = ""; - let entry = "if (true) { new Std.Math.Complex { Real = 0.0, Imag = 0.0 } } else { fail \"shouldn't get here\" }"; - let sources = SourceMap::new([("test.qs".into(), source.into())], Some(entry.into())); - let language_features = LanguageFeatures::default(); - let capabilities = TargetCapabilityFlags::empty(); - let (std_id, store) = crate::compile::package_store_with_stdlib(capabilities); +fn explicit_return_tuple_keeps_dynamic_integer_output() { + let source = indoc::indoc! 
{r#" + namespace Test { + import Std.Measurement.*; + + @EntryPoint() + operation Main() : (Int, Bool) { + use q = Qubit(); + mutable a = 0; + if MResetZ(q) == Zero { + set a = 1; + } else { + set a = 2; + } + + use p = Qubit(); + return (a, MResetZ(p) == One); + } + } + "#}; + + let qir = compile_source_to_qir( + source, + TargetCapabilityFlags::Adaptive | TargetCapabilityFlags::IntegerComputations, + ); expect![[r#" - Err( - [ - Pass( - WithSource { - sources: [ - Source { - name: "", - contents: "if (true) { new Std.Math.Complex { Real = 0.0, Imag = 0.0 } } else { fail \"shouldn't get here\" }", - offset: 0, - }, - ], - error: CapabilitiesCk( - UseOfAdvancedOutput( - Span { - lo: 0, - hi: 96, - }, - ), - ), - }, - ), - ], - ) + %Result = type opaque + %Qubit = type opaque + + @0 = internal constant [4 x i8] c"0_t\00" + @1 = internal constant [6 x i8] c"1_t0i\00" + @2 = internal constant [6 x i8] c"2_t1b\00" + + define i64 @ENTRYPOINT__main() #0 { + block_0: + call void @__quantum__rt__initialize(i8* null) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + %var_1 = call i1 @__quantum__rt__read_result(%Result* inttoptr (i64 0 to %Result*)) + %var_2 = icmp eq i1 %var_1, false + br i1 %var_2, label %block_1, label %block_2 + block_1: + br label %block_3 + block_2: + br label %block_3 + block_3: + %var_5 = phi i64 [1, %block_1], [2, %block_2] + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) + %var_3 = call i1 @__quantum__rt__read_result(%Result* inttoptr (i64 1 to %Result*)) + call void @__quantum__rt__tuple_record_output(i64 2, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + call void @__quantum__rt__int_record_output(i64 %var_5, i8* getelementptr inbounds ([6 x i8], [6 x i8]* @1, i64 0, i64 0)) + call void @__quantum__rt__bool_record_output(i1 %var_3, i8* getelementptr inbounds ([6 x i8], [6 x i8]* @2, i64 
0, i64 0)) + ret i64 0 + } + + declare void @__quantum__rt__initialize(i8*) + + declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 + + declare i1 @__quantum__rt__read_result(%Result*) + + declare void @__quantum__rt__tuple_record_output(i64, i8*) + + declare void @__quantum__rt__int_record_output(i64, i8*) + + declare void @__quantum__rt__bool_record_output(i1, i8*) + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="2" "required_num_results"="2" } + attributes #1 = { "irreversible" } + + ; module flags + + !llvm.module.flags = !{!0, !1, !2, !3, !4} + + !0 = !{i32 1, !"qir_major_version", i32 1} + !1 = !{i32 7, !"qir_minor_version", i32 0} + !2 = !{i32 1, !"dynamic_qubit_management", i1 false} + !3 = !{i32 1, !"dynamic_result_management", i1 false} + !4 = !{i32 5, !"int_computations", !{!"i64"}} "#]] - .assert_debug_eq(&get_qir( - sources, - language_features, - capabilities, - store, - &[(std_id, None)], - )); + .assert_eq(&qir); } -mod base_profile { - use expect_test::expect; - use qsc_data_structures::target::TargetCapabilityFlags; +#[test] +fn result_array_helper_return_survives_adaptive_codegen_prep() { + let source = indoc::indoc! 
{r#" + namespace Test { + import Std.Measurement.*; - use super::compile_source_to_qir; - static CAPABILITIES: std::sync::LazyLock = - std::sync::LazyLock::new(TargetCapabilityFlags::empty); + @EntryPoint() + operation Main() : Result[] { + use register = Qubit[2]; + return MResetZ2Register(register); + } + + operation MResetZ2Register(register : Qubit[]) : Result[] { + return [MResetZ(register[0]), MResetZ(register[1])]; + } + } + "#}; + + let qir = compile_source_to_qir( + source, + TargetCapabilityFlags::Adaptive | TargetCapabilityFlags::IntegerComputations, + ); + + expect![[r#" + %Result = type opaque + %Qubit = type opaque + + @0 = internal constant [4 x i8] c"0_a\00" + @1 = internal constant [6 x i8] c"1_a0r\00" + @2 = internal constant [6 x i8] c"2_a1r\00" + + define i64 @ENTRYPOINT__main() #0 { + block_0: + call void @__quantum__rt__initialize(i8* null) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) + call void @__quantum__rt__array_record_output(i64 2, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @1, i64 0, i64 0)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @2, i64 0, i64 0)) + ret i64 0 + } + + declare void @__quantum__rt__initialize(i8*) + + declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 + + declare void @__quantum__rt__array_record_output(i64, i8*) + + declare void @__quantum__rt__result_record_output(%Result*, i8*) + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="2" "required_num_results"="2" } + attributes #1 = { "irreversible" } 
+ + ; module flags + + !llvm.module.flags = !{!0, !1, !2, !3, !4} + + !0 = !{i32 1, !"qir_major_version", i32 1} + !1 = !{i32 7, !"qir_minor_version", i32 0} + !2 = !{i32 1, !"dynamic_qubit_management", i1 false} + !3 = !{i32 1, !"dynamic_result_management", i1 false} + !4 = !{i32 5, !"int_computations", !{!"i64"}} + "#]] + .assert_eq(&qir); +} + +#[test] +fn higher_order_closure_captures_are_threaded_into_specialized_calls() { + let source = indoc::indoc! {r#" + namespace Test { + import Std.Canon.*; + import Std.Measurement.*; + + operation ApplyOp(op : (Qubit[] => Unit), register : Qubit[]) : Result[] { + op(register); + return MResetEachZ(register); + } - #[test] - fn simple() { - let source = "namespace Test { - import Std.Math.*; - open QIR.Intrinsic; @EntryPoint() - operation Main() : Result { - use q = Qubit(); - let pi_over_two = 4.0 / 2.0; - __quantum__qis__rz__body(pi_over_two, q); - mutable some_angle = ArcSin(0.0); - __quantum__qis__rz__body(some_angle, q); - set some_angle = ArcCos(-1.0) / PI(); - __quantum__qis__rz__body(some_angle, q); - __quantum__qis__mresetz__body(q) + operation Main() : Result[] { + use register = Qubit[2]; + return ApplyOp(register => Shifted(1, register), register); + } + + operation Shifted(shift : Int, register : Qubit[]) : Unit { + ApplyXorInPlace(shift, register); + } + } + "#}; + + let qir = compile_source_to_qir( + source, + TargetCapabilityFlags::Adaptive | TargetCapabilityFlags::IntegerComputations, + ); + + expect![[r#" + %Result = type opaque + %Qubit = type opaque + + @0 = internal constant [4 x i8] c"0_a\00" + @1 = internal constant [6 x i8] c"1_a0r\00" + @2 = internal constant [6 x i8] c"2_a1r\00" + + define i64 @ENTRYPOINT__main() #0 { + block_0: + call void @__quantum__rt__initialize(i8* null) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void 
@__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) + call void @__quantum__rt__array_record_output(i64 2, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @1, i64 0, i64 0)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @2, i64 0, i64 0)) + ret i64 0 + } + + declare void @__quantum__rt__initialize(i8*) + + declare void @__quantum__qis__x__body(%Qubit*) + + declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 + + declare void @__quantum__rt__array_record_output(i64, i8*) + + declare void @__quantum__rt__result_record_output(%Result*, i8*) + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="2" "required_num_results"="2" } + attributes #1 = { "irreversible" } + + ; module flags + + !llvm.module.flags = !{!0, !1, !2, !3, !4} + + !0 = !{i32 1, !"qir_major_version", i32 1} + !1 = !{i32 7, !"qir_minor_version", i32 0} + !2 = !{i32 1, !"dynamic_qubit_management", i1 false} + !3 = !{i32 1, !"dynamic_result_management", i1 false} + !4 = !{i32 5, !"int_computations", !{!"i64"}} + "#]] + .assert_eq(&qir); +} + +#[test] +fn two_callable_hof_closure_preserves_array_arg_threading() { + let source = indoc::indoc! 
{r#" + namespace Test { + import Std.Arrays.*; + import Std.Canon.*; + import Std.Convert.*; + import Std.Measurement.*; + + operation Outer(Ufstar : (Qubit[] => Unit), Ug : (Qubit[] => Unit), n : Int) : Result[] { + use qubits = Qubit[n]; + Ug(qubits); + return MResetEachZ(qubits); + } + + operation Empty(register : Qubit[]) : Unit { + } + + operation ShiftedSimple(shift : Int, register : Qubit[]) : Unit { + ApplyXorInPlace(shift, register); + } + + @EntryPoint() + operation Main() : Result[] { + let bits = [true, false]; + let shift = BoolArrayAsInt(bits); + let n = Length(bits); + return Outer(Empty, register => ShiftedSimple(shift, register), n); + } + } + "#}; + + let qir = compile_source_to_qir( + source, + TargetCapabilityFlags::Adaptive | TargetCapabilityFlags::IntegerComputations, + ); + + expect![[r#" + %Result = type opaque + %Qubit = type opaque + + @0 = internal constant [4 x i8] c"0_a\00" + @1 = internal constant [6 x i8] c"1_a0r\00" + @2 = internal constant [6 x i8] c"2_a1r\00" + + define i64 @ENTRYPOINT__main() #0 { + block_0: + call void @__quantum__rt__initialize(i8* null) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) + call void @__quantum__rt__array_record_output(i64 2, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @1, i64 0, i64 0)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @2, i64 0, i64 0)) + ret i64 0 + } + + declare void @__quantum__rt__initialize(i8*) + + declare void @__quantum__qis__x__body(%Qubit*) + + declare void 
@__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 + + declare void @__quantum__rt__array_record_output(i64, i8*) + + declare void @__quantum__rt__result_record_output(%Result*, i8*) + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="2" "required_num_results"="2" } + attributes #1 = { "irreversible" } + + ; module flags + + !llvm.module.flags = !{!0, !1, !2, !3, !4} + + !0 = !{i32 1, !"qir_major_version", i32 1} + !1 = !{i32 7, !"qir_minor_version", i32 0} + !2 = !{i32 1, !"dynamic_qubit_management", i1 false} + !3 = !{i32 1, !"dynamic_result_management", i1 false} + !4 = !{i32 5, !"int_computations", !{!"i64"}} + "#]] + .assert_eq(&qir); +} + +#[test] +fn callable_args_with_arrow_input_survives_dce() { + let source = indoc::indoc! {r#" + namespace Test { + operation ApplyOp(op : Qubit => Unit) : Result { + use q = Qubit(); + op(q); + MResetZ(q) + } + operation MyOp(q : Qubit) : Unit { H(q); } + } + "#}; + + let capabilities = TargetCapabilityFlags::Adaptive + | TargetCapabilityFlags::IntegerComputations + | TargetCapabilityFlags::FloatingPointComputations; + + let sources = source_map_from_source(source); + let language_features = LanguageFeatures::default(); + let (std_id, mut store) = crate::compile::package_store_with_stdlib(capabilities); + let dependencies: Vec<(PackageId, Option>)> = vec![(std_id, None)]; + + let (unit, errors) = crate::compile::compile( + &store, + &dependencies, + sources, + qsc_passes::PackageType::Lib, + capabilities, + language_features, + ); + assert!(errors.is_empty(), "compilation failed: {errors:?}"); + let package_id = store.insert(unit); + + // Find ApplyOp and MyOp by name in the HIR package. 
+ let hir_package = &store.get(package_id).expect("package should exist").package; + let mut apply_op_local = None; + let mut my_op_local = None; + for (local_id, item) in hir_package.items.iter() { + if let ItemKind::Callable(decl) = &item.kind { + if decl.name.name.as_ref() == "ApplyOp" { + apply_op_local = Some(local_id); + } else if decl.name.name.as_ref() == "MyOp" { + my_op_local = Some(local_id); + } + } + } + let apply_op_local = apply_op_local.expect("ApplyOp should exist in HIR"); + let my_op_local = my_op_local.expect("MyOp should exist in HIR"); + + let apply_op_hir_id = qsc_hir::hir::ItemId { + package: package_id, + item: apply_op_local, + }; + + // Construct Value::Global for MyOp using FIR StoreItemId. + let my_op_fir_id = qsc_fir::fir::StoreItemId { + package: qsc_lowerer::map_hir_package_to_fir(package_id), + item: qsc_lowerer::map_hir_local_item_to_fir(my_op_local), + }; + let my_op_value = Value::Global(my_op_fir_id, FunctorApp::default()); + + // The synthetic Call path makes ApplyOp entry-reachable. Defunc specializes + // it to ApplyOp{MyOp}, and the pipeline transforms it fully. The original + // ApplyOp is pinned for DCE survival so fir_to_qir_from_callable can use + // the original ID with the original-shaped args. 
+ let codegen_fir = + prepare_codegen_fir_from_callable_args(&store, apply_op_hir_id, &my_op_value, capabilities) + .unwrap_or_else(|errors| { + panic!( + "callable-args with arrow-input should survive DCE, got: {}", + format_interpret_errors(errors) + ) + }); + + let backend_callable = qsc_fir::fir::StoreItemId { + package: qsc_lowerer::map_hir_package_to_fir(apply_op_hir_id.package), + item: qsc_lowerer::map_hir_local_item_to_fir(apply_op_hir_id.item), + }; + + let qir = qsc_codegen::qir::fir_to_qir_from_callable( + &codegen_fir.fir_store, + capabilities, + &codegen_fir.compute_properties, + backend_callable, + my_op_value, + ) + .unwrap_or_else(|e| panic!("QIR generation from arrow-input callable should succeed: {e:?}")); + + expect![[r#" + %Result = type opaque + %Qubit = type opaque + + @0 = internal constant [4 x i8] c"0_r\00" + + define i64 @ENTRYPOINT__main() #0 { + block_0: + call void @__quantum__rt__initialize(i8* null) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + ret i64 0 + } + + declare void @__quantum__rt__initialize(i8*) + + declare void @__quantum__qis__h__body(%Qubit*) + + declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 + + declare void @__quantum__rt__result_record_output(%Result*, i8*) + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="1" } + attributes #1 = { "irreversible" } + + ; module flags + + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5} + + !0 = !{i32 1, !"qir_major_version", i32 1} + !1 = !{i32 7, !"qir_minor_version", i32 0} + !2 = !{i32 1, !"dynamic_qubit_management", i1 false} + !3 = !{i32 1, !"dynamic_result_management", 
i1 false} + !4 = !{i32 5, !"int_computations", !{!"i64"}} + !5 = !{i32 5, !"float_computations", !{!"double"}} + "#]] + .assert_eq(&qir); +} + +#[test] +fn callable_args_with_udt_wrapped_arrow_survives_dce() { + let source = indoc::indoc! {r#" + namespace Test { + newtype Config = (Op: Qubit => Unit, Data: Int); + operation Apply(cfg: Config) : Unit { + use q = Qubit(); + cfg::Op(q); + } + operation MyOp(q: Qubit) : Unit { H(q); } + } + "#}; + + let capabilities = TargetCapabilityFlags::Adaptive + | TargetCapabilityFlags::IntegerComputations + | TargetCapabilityFlags::FloatingPointComputations; + + let sources = source_map_from_source(source); + let language_features = LanguageFeatures::default(); + let (std_id, mut store) = crate::compile::package_store_with_stdlib(capabilities); + let dependencies: Vec<(PackageId, Option>)> = vec![(std_id, None)]; + + let (unit, errors) = crate::compile::compile( + &store, + &dependencies, + sources, + qsc_passes::PackageType::Lib, + capabilities, + language_features, + ); + assert!(errors.is_empty(), "compilation failed: {errors:?}"); + let package_id = store.insert(unit); + + let hir_package = &store.get(package_id).expect("package should exist").package; + let mut apply_local = None; + let mut my_op_local = None; + let mut config_udt_local = None; + for (local_id, item) in hir_package.items.iter() { + match &item.kind { + ItemKind::Callable(decl) if decl.name.name.as_ref() == "Apply" => { + apply_local = Some(local_id); + } + ItemKind::Callable(decl) if decl.name.name.as_ref() == "MyOp" => { + my_op_local = Some(local_id); + } + ItemKind::Ty(name, _) if name.name.as_ref() == "Config" => { + config_udt_local = Some(local_id); + } + _ => {} + } + } + let apply_local = apply_local.expect("Apply should exist in HIR"); + let my_op_local = my_op_local.expect("MyOp should exist in HIR"); + + let apply_hir_id = qsc_hir::hir::ItemId { + package: package_id, + item: apply_local, + }; + + let my_op_fir_id = qsc_fir::fir::StoreItemId { + 
package: qsc_lowerer::map_hir_package_to_fir(package_id), + item: qsc_lowerer::map_hir_local_item_to_fir(my_op_local), + }; + let my_op_value = Value::Global(my_op_fir_id, FunctorApp::default()); + + // Build a Config UDT value: Config(MyOp, 42) + // UDT values are Value::Tuple(Rc<[Value]>, Option>) + let config_fir_id = qsc_fir::fir::StoreItemId { + package: qsc_lowerer::map_hir_package_to_fir(package_id), + item: qsc_lowerer::map_hir_local_item_to_fir( + config_udt_local.expect("Config UDT should exist"), + ), + }; + let config_value = Value::Tuple( + vec![my_op_value, Value::Int(42)].into(), + Some(Rc::new(config_fir_id)), + ); + + let result = + prepare_codegen_fir_from_callable_args(&store, apply_hir_id, &config_value, capabilities); + match result { + Ok(_) => {} + Err(errors) => panic!( + "callable-args with UDT-wrapped arrow should survive DCE, got: {}", + format_interpret_errors(errors) + ), + } +} + +#[test] +fn callable_with_udt_wrapped_arrow_generates_qir_via_callable_args() { + let source = indoc::indoc! 
{r#" + namespace Test { + newtype Config = (Op: Qubit => Unit, Data: Int); + operation Apply(cfg: Config) : Result { + use q = Qubit(); + cfg::Op(q); + MResetZ(q) + } + operation MyOp(q: Qubit) : Unit { H(q); } + } + "#}; + + let capabilities = TargetCapabilityFlags::Adaptive + | TargetCapabilityFlags::IntegerComputations + | TargetCapabilityFlags::FloatingPointComputations; + + let sources = source_map_from_source(source); + let language_features = LanguageFeatures::default(); + let (std_id, mut store) = crate::compile::package_store_with_stdlib(capabilities); + let dependencies: Vec<(PackageId, Option>)> = vec![(std_id, None)]; + + let (unit, errors) = crate::compile::compile( + &store, + &dependencies, + sources, + qsc_passes::PackageType::Lib, + capabilities, + language_features, + ); + assert!(errors.is_empty(), "compilation failed: {errors:?}"); + let package_id = store.insert(unit); + + let hir_package = &store.get(package_id).expect("package should exist").package; + let mut apply_local = None; + let mut my_op_local = None; + let mut config_udt_local = None; + for (local_id, item) in hir_package.items.iter() { + match &item.kind { + ItemKind::Callable(decl) if decl.name.name.as_ref() == "Apply" => { + apply_local = Some(local_id); + } + ItemKind::Callable(decl) if decl.name.name.as_ref() == "MyOp" => { + my_op_local = Some(local_id); + } + ItemKind::Ty(name, _) if name.name.as_ref() == "Config" => { + config_udt_local = Some(local_id); + } + _ => {} + } + } + let apply_local = apply_local.expect("Apply should exist in HIR"); + let my_op_local = my_op_local.expect("MyOp should exist in HIR"); + + let apply_hir_id = qsc_hir::hir::ItemId { + package: package_id, + item: apply_local, + }; + + let my_op_fir_id = qsc_fir::fir::StoreItemId { + package: qsc_lowerer::map_hir_package_to_fir(package_id), + item: qsc_lowerer::map_hir_local_item_to_fir(my_op_local), + }; + let my_op_value = Value::Global(my_op_fir_id, FunctorApp::default()); + + let config_fir_id = 
qsc_fir::fir::StoreItemId { + package: qsc_lowerer::map_hir_package_to_fir(package_id), + item: qsc_lowerer::map_hir_local_item_to_fir( + config_udt_local.expect("Config UDT should exist"), + ), + }; + let config_value = Value::Tuple( + vec![my_op_value, Value::Int(42)].into(), + Some(Rc::new(config_fir_id)), + ); + + let codegen_fir = + prepare_codegen_fir_from_callable_args(&store, apply_hir_id, &config_value, capabilities) + .unwrap_or_else(|errors| { + panic!( + "callable-args with UDT-wrapped arrow should produce CodegenFir, got: {}", + format_interpret_errors(errors) + ) + }); + + let backend_callable = qsc_fir::fir::StoreItemId { + package: qsc_lowerer::map_hir_package_to_fir(apply_hir_id.package), + item: qsc_lowerer::map_hir_local_item_to_fir(apply_hir_id.item), + }; + + let qir = qsc_codegen::qir::fir_to_qir_from_callable( + &codegen_fir.fir_store, + capabilities, + &codegen_fir.compute_properties, + backend_callable, + config_value, + ) + .unwrap_or_else(|e| panic!("QIR generation from UDT-wrapped arrow should succeed: {e:?}")); + + expect![[r#" + %Result = type opaque + %Qubit = type opaque + + @0 = internal constant [4 x i8] c"0_r\00" + + define i64 @ENTRYPOINT__main() #0 { + block_0: + call void @__quantum__rt__initialize(i8* null) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + ret i64 0 + } + + declare void @__quantum__rt__initialize(i8*) + + declare void @__quantum__qis__h__body(%Qubit*) + + declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 + + declare void @__quantum__rt__result_record_output(%Result*, i8*) + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" 
"required_num_results"="1" } + attributes #1 = { "irreversible" } + + ; module flags + + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5} + + !0 = !{i32 1, !"qir_major_version", i32 1} + !1 = !{i32 7, !"qir_minor_version", i32 0} + !2 = !{i32 1, !"dynamic_qubit_management", i1 false} + !3 = !{i32 1, !"dynamic_result_management", i1 false} + !4 = !{i32 5, !"int_computations", !{!"i64"}} + !5 = !{i32 5, !"float_computations", !{!"double"}} + "#]] + .assert_eq(&qir); +} + +#[test] +fn callable_with_nested_udt_wrapped_arrow_generates_qir_via_callable_args() { + let source = indoc::indoc! {r#" + namespace Test { + newtype OpWrapper = (Op: Qubit => Unit); + newtype Config = (Inner: OpWrapper, Count: Int); + operation Apply(cfg: Config) : Result { + use q = Qubit(); + cfg::Inner::Op(q); + MResetZ(q) + } + operation MyOp(q: Qubit) : Unit { X(q); } + } + "#}; + + let capabilities = TargetCapabilityFlags::Adaptive + | TargetCapabilityFlags::IntegerComputations + | TargetCapabilityFlags::FloatingPointComputations; + + let sources = source_map_from_source(source); + let language_features = LanguageFeatures::default(); + let (std_id, mut store) = crate::compile::package_store_with_stdlib(capabilities); + let dependencies: Vec<(PackageId, Option>)> = vec![(std_id, None)]; + + let (unit, errors) = crate::compile::compile( + &store, + &dependencies, + sources, + qsc_passes::PackageType::Lib, + capabilities, + language_features, + ); + assert!(errors.is_empty(), "compilation failed: {errors:?}"); + let package_id = store.insert(unit); + + let hir_package = &store.get(package_id).expect("package should exist").package; + let mut apply_local = None; + let mut my_op_local = None; + let mut config_udt_local = None; + for (local_id, item) in hir_package.items.iter() { + match &item.kind { + ItemKind::Callable(decl) if decl.name.name.as_ref() == "Apply" => { + apply_local = Some(local_id); + } + ItemKind::Callable(decl) if decl.name.name.as_ref() == "MyOp" => { + my_op_local = 
Some(local_id); + } + ItemKind::Ty(name, _) if name.name.as_ref() == "Config" => { + config_udt_local = Some(local_id); + } + _ => {} + } + } + let apply_local = apply_local.expect("Apply should exist in HIR"); + let my_op_local = my_op_local.expect("MyOp should exist in HIR"); + + let apply_hir_id = qsc_hir::hir::ItemId { + package: package_id, + item: apply_local, + }; + + let my_op_fir_id = qsc_fir::fir::StoreItemId { + package: qsc_lowerer::map_hir_package_to_fir(package_id), + item: qsc_lowerer::map_hir_local_item_to_fir(my_op_local), + }; + let my_op_value = Value::Global(my_op_fir_id, FunctorApp::default()); + + let config_fir_id = qsc_fir::fir::StoreItemId { + package: qsc_lowerer::map_hir_package_to_fir(package_id), + item: qsc_lowerer::map_hir_local_item_to_fir( + config_udt_local.expect("Config UDT should exist"), + ), + }; + let config_value = Value::Tuple( + vec![my_op_value, Value::Int(5)].into(), + Some(Rc::new(config_fir_id)), + ); + + let codegen_fir = prepare_codegen_fir_from_callable_args( + &store, + apply_hir_id, + &config_value, + capabilities, + ) + .unwrap_or_else(|errors| { + panic!( + "callable-args with nested UDT-wrapped arrow should produce CodegenFir, got: {}", + format_interpret_errors(errors) + ) + }); + + let backend_callable = qsc_fir::fir::StoreItemId { + package: qsc_lowerer::map_hir_package_to_fir(apply_hir_id.package), + item: qsc_lowerer::map_hir_local_item_to_fir(apply_hir_id.item), + }; + + let qir = qsc_codegen::qir::fir_to_qir_from_callable( + &codegen_fir.fir_store, + capabilities, + &codegen_fir.compute_properties, + backend_callable, + config_value, + ) + .unwrap_or_else(|e| { + panic!("QIR generation from nested UDT-wrapped arrow should succeed: {e:?}") + }); + + expect![[r#" + %Result = type opaque + %Qubit = type opaque + + @0 = internal constant [4 x i8] c"0_r\00" + + define i64 @ENTRYPOINT__main() #0 { + block_0: + call void @__quantum__rt__initialize(i8* null) + call void @__quantum__qis__x__body(%Qubit* inttoptr 
(i64 0 to %Qubit*)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + ret i64 0 + } + + declare void @__quantum__rt__initialize(i8*) + + declare void @__quantum__qis__x__body(%Qubit*) + + declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 + + declare void @__quantum__rt__result_record_output(%Result*, i8*) + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="1" } + attributes #1 = { "irreversible" } + + ; module flags + + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5} + + !0 = !{i32 1, !"qir_major_version", i32 1} + !1 = !{i32 7, !"qir_minor_version", i32 0} + !2 = !{i32 1, !"dynamic_qubit_management", i1 false} + !3 = !{i32 1, !"dynamic_result_management", i1 false} + !4 = !{i32 5, !"int_computations", !{!"i64"}} + !5 = !{i32 5, !"float_computations", !{!"double"}} + "#]].assert_eq(&qir); +} + +// --------------------------------------------------------------------------- +// Synthetic-path and fallback-path coverage for callable-args codegen +// --------------------------------------------------------------------------- + +/// Helper: compile a lib package, locate named items, and return (`store`, `package_id`, `items_map`). +/// `item_names` maps display names to a bool: true = Callable, false = Ty (UDT). 
+fn compile_and_locate_items( + source: &str, + item_names: &[(&str, bool)], + capabilities: TargetCapabilityFlags, +) -> ( + crate::PackageStore, + PackageId, + FxHashMap, +) { + let sources = source_map_from_source(source); + let language_features = LanguageFeatures::default(); + let (std_id, mut store) = crate::compile::package_store_with_stdlib(capabilities); + let dependencies: Vec<(PackageId, Option>)> = vec![(std_id, None)]; + let (unit, errors) = crate::compile::compile( + &store, + &dependencies, + sources, + qsc_passes::PackageType::Lib, + capabilities, + language_features, + ); + assert!(errors.is_empty(), "compilation failed: {errors:?}"); + let package_id = store.insert(unit); + + let hir_package = &store.get(package_id).expect("package should exist").package; + let mut found = FxHashMap::default(); + for (local_id, item) in hir_package.items.iter() { + match &item.kind { + ItemKind::Callable(decl) => { + for &(name, is_callable) in item_names { + if is_callable && decl.name.name.as_ref() == name { + found.insert(name.to_string(), local_id); + } + } + } + ItemKind::Ty(name, _) => { + for &(item_name, is_callable) in item_names { + if !is_callable && name.name.as_ref() == item_name { + found.insert(item_name.to_string(), local_id); + } + } + } + _ => {} + } + } + for &(name, _) in item_names { + assert!( + found.contains_key(name), + "{name} should exist in HIR package" + ); + } + (store, package_id, found) +} + +/// Returns the target capabilities used by callable-args synthetic path tests. +fn adaptive_capabilities() -> TargetCapabilityFlags { + TargetCapabilityFlags::Adaptive + | TargetCapabilityFlags::IntegerComputations + | TargetCapabilityFlags::FloatingPointComputations +} + +/// Maps a HIR local item ID in the test package to its corresponding FIR item ID. 
+fn fir_id_for( + package_id: PackageId, + local_id: qsc_hir::hir::LocalItemId, +) -> qsc_fir::fir::StoreItemId { + qsc_fir::fir::StoreItemId { + package: qsc_lowerer::map_hir_package_to_fir(package_id), + item: qsc_lowerer::map_hir_local_item_to_fir(local_id), + } +} + +/// Builds a HIR item ID from a test package ID and local item ID. +fn hir_id_for(package_id: PackageId, local_id: qsc_hir::hir::LocalItemId) -> qsc_hir::hir::ItemId { + qsc_hir::hir::ItemId { + package: package_id, + item: local_id, + } +} + +/// Runs `prepare_codegen_fir_from_callable_args` and then `fir_to_qir_from_callable`, +/// returning the QIR string. +fn callable_args_to_qir( + store: &crate::PackageStore, + package_id: PackageId, + target_local: qsc_hir::hir::LocalItemId, + args: &Value, + capabilities: TargetCapabilityFlags, +) -> String { + let target_hir = hir_id_for(package_id, target_local); + let codegen_fir = prepare_codegen_fir_from_callable_args(store, target_hir, args, capabilities) + .unwrap_or_else(|errors| { + panic!( + "prepare_codegen_fir_from_callable_args failed: {}", + format_interpret_errors(errors) + ) + }); + let backend_callable = fir_id_for(package_id, target_local); + qsc_codegen::qir::fir_to_qir_from_callable( + &codegen_fir.fir_store, + capabilities, + &codegen_fir.compute_properties, + backend_callable, + args.clone(), + ) + .unwrap_or_else(|e| panic!("fir_to_qir_from_callable failed: {e:?}")) +} + +// ---- Synthetic path: arrow + non-callable params (tuple input) ---- + +#[test] +fn synthetic_path_arrow_and_int_tuple_generates_qir() { + // Target takes (op: Qubit => Unit, count: Int). Only the callable flows + // through `args`; count is provided as a plain Int value. + let source = indoc::indoc! 
{r#" + namespace Test { + operation RunOp(op : Qubit => Unit, count : Int) : Result { + use q = Qubit(); + for _ in 0..count - 1 { + op(q); + } + MResetZ(q) + } + operation MyH(q : Qubit) : Unit { H(q); } + } + "#}; + let caps = adaptive_capabilities(); + let (store, pkg, items) = + compile_and_locate_items(source, &[("RunOp", true), ("MyH", true)], caps); + + let my_h = Value::Global(fir_id_for(pkg, items["MyH"]), FunctorApp::default()); + let args = Value::Tuple(vec![my_h, Value::Int(3)].into(), None); + + let qir = callable_args_to_qir(&store, pkg, items["RunOp"], &args, caps); + // The QIR must contain an h__body call from the loop body. + assert!( + qir.contains("__quantum__qis__h__body"), + "expected h gate in QIR:\n{qir}" + ); + assert!( + qir.contains("__quantum__qis__mresetz__body"), + "expected mresetz in QIR:\n{qir}" + ); +} + +// ---- Synthetic path: two callable args in a tuple ---- + +#[test] +fn synthetic_path_two_arrow_args_generates_qir() { + // Target takes (op1: Qubit => Unit, op2: Qubit => Unit). Both are + // Global values — the synthetic Call must place both at their respective + // tuple positions. + let source = indoc::indoc! 
{r#" + namespace Test { + operation ApplyBoth(op1 : Qubit => Unit, op2 : Qubit => Unit) : Result { + use q = Qubit(); + op1(q); + op2(q); + MResetZ(q) + } + operation DoH(q : Qubit) : Unit { H(q); } + operation DoX(q : Qubit) : Unit { X(q); } + } + "#}; + let caps = adaptive_capabilities(); + let (store, pkg, items) = compile_and_locate_items( + source, + &[("ApplyBoth", true), ("DoH", true), ("DoX", true)], + caps, + ); + + let do_h = Value::Global(fir_id_for(pkg, items["DoH"]), FunctorApp::default()); + let do_x = Value::Global(fir_id_for(pkg, items["DoX"]), FunctorApp::default()); + let args = Value::Tuple(vec![do_h, do_x].into(), None); + + let qir = callable_args_to_qir(&store, pkg, items["ApplyBoth"], &args, caps); + assert!( + qir.contains("__quantum__qis__h__body"), + "expected H gate in QIR:\n{qir}" + ); + assert!( + qir.contains("__quantum__qis__x__body"), + "expected X gate in QIR:\n{qir}" + ); +} + +// ---- Synthetic path: arrow sandwiched between non-callable params ---- + +#[test] +fn synthetic_path_int_arrow_bool_tuple_generates_qir() { + // Target takes (n: Int, op: Qubit => Unit, flag: Bool). The callable is + // in the middle of the tuple — exercises the element-wise matching logic + // in `build_synthetic_args`. + let source = indoc::indoc! 
{r#" + namespace Test { + operation Middle(n : Int, op : Qubit => Unit, flag : Bool) : Result { + use q = Qubit(); + if flag { + for _ in 0..n - 1 { op(q); } + } + MResetZ(q) + } + operation DoX(q : Qubit) : Unit { X(q); } + } + "#}; + let caps = adaptive_capabilities(); + let (store, pkg, items) = + compile_and_locate_items(source, &[("Middle", true), ("DoX", true)], caps); + + let do_x = Value::Global(fir_id_for(pkg, items["DoX"]), FunctorApp::default()); + let args = Value::Tuple(vec![Value::Int(2), do_x, Value::Bool(true)].into(), None); + + let qir = callable_args_to_qir(&store, pkg, items["Middle"], &args, caps); + assert!( + qir.contains("__quantum__qis__x__body"), + "expected X gate in QIR:\n{qir}" + ); +} + +// ---- Synthetic path: no callable args (pure values) ---- + +#[test] +fn no_callable_args_takes_early_return_path() { + // When args contain no callable values, `prepare_codegen_fir_from_callable_args` + // takes the `concrete_callables.is_empty()` early return to `prepare_codegen_fir_from_callable`. + // This exercises that branch. + let source = indoc::indoc! {r#" + namespace Test { + operation Simple(n : Int) : Result { + use q = Qubit(); + for _ in 0..n - 1 { H(q); } + MResetZ(q) + } + } + "#}; + let caps = adaptive_capabilities(); + let (store, pkg, items) = compile_and_locate_items(source, &[("Simple", true)], caps); + + let args = Value::Int(3); + let target_hir = hir_id_for(pkg, items["Simple"]); + + // Should succeed without error — takes the no-callable early path. + let result = prepare_codegen_fir_from_callable_args(&store, target_hir, &args, caps); + assert!( + result.is_ok(), + "no-callable args should succeed: {:?}", + result.err().map(format_interpret_errors) + ); +} + +// ---- Synthetic path: struct with callable and non-callable fields ---- + +#[test] +fn synthetic_path_struct_with_callable_field_generates_qir() { + // `Config` is a newtype wrapping (Op: Qubit => Unit, Data: Int). 
+ // The synthetic Call builder resolves the UDT's pure tuple shape so defunc + // can discover and specialize the callable field. + let source = indoc::indoc! {r#" + namespace Test { + newtype Config = (Op: Qubit => Unit, Data: Int); + operation Apply(cfg: Config) : Result { + use q = Qubit(); + cfg::Op(q); + MResetZ(q) + } + operation DoH(q: Qubit) : Unit { H(q); } + } + "#}; + let caps = adaptive_capabilities(); + let (store, pkg, items) = compile_and_locate_items( + source, + &[("Apply", true), ("DoH", true), ("Config", false)], + caps, + ); + + let do_h = Value::Global(fir_id_for(pkg, items["DoH"]), FunctorApp::default()); + let config = Value::Tuple( + vec![do_h, Value::Int(42)].into(), + Some(Rc::new(fir_id_for(pkg, items["Config"]))), + ); + + let qir = callable_args_to_qir(&store, pkg, items["Apply"], &config, caps); + assert!( + qir.contains("__quantum__qis__h__body"), + "expected H gate in QIR:\n{qir}" + ); +} + +// ---- No-callable path: struct with only non-callable fields ---- + +#[test] +fn struct_with_no_callable_fields_takes_early_return_path() { + // A UDT that contains no callable fields takes the `concrete_callables.is_empty()` + // early return. + let source = indoc::indoc! 
{r#" + namespace Test { + newtype Pair = (First: Int, Second: Int); + operation Sum(p: Pair) : Result { + use q = Qubit(); + let total = p::First + p::Second; + if total > 0 { H(q); } + MResetZ(q) + } + } + "#}; + let caps = adaptive_capabilities(); + let (store, pkg, items) = + compile_and_locate_items(source, &[("Sum", true), ("Pair", false)], caps); + + let pair = Value::Tuple( + vec![Value::Int(3), Value::Int(5)].into(), + Some(Rc::new(fir_id_for(pkg, items["Pair"]))), + ); + let target_hir = hir_id_for(pkg, items["Sum"]); + + let result = prepare_codegen_fir_from_callable_args(&store, target_hir, &pair, caps); + assert!( + result.is_ok(), + "struct with no callable fields should succeed: {:?}", + result.err().map(format_interpret_errors) + ); +} + +// ---- Synthetic path: single Global arg (not in a tuple) ---- + +#[test] +fn synthetic_path_single_global_arg_generates_qir() { + // The simplest synthetic path: a single callable arg, not wrapped in a tuple. + // `build_synthetic_args` hits the `Ty::Arrow` branch directly. + let source = indoc::indoc! {r#" + namespace Test { + operation Invoke(op : Qubit => Unit) : Result { + use q = Qubit(); + op(q); + MResetZ(q) + } + operation DoX(q : Qubit) : Unit { X(q); } + } + "#}; + let caps = adaptive_capabilities(); + let (store, pkg, items) = + compile_and_locate_items(source, &[("Invoke", true), ("DoX", true)], caps); + + let do_x = Value::Global(fir_id_for(pkg, items["DoX"]), FunctorApp::default()); + + let qir = callable_args_to_qir(&store, pkg, items["Invoke"], &do_x, caps); + assert!( + qir.contains("__quantum__qis__x__body"), + "expected X gate in QIR:\n{qir}" + ); +} + +#[test] +fn synthetic_path_captureless_closure_adjoint_preserves_functor() { + let source = indoc::indoc! 
{r#" + namespace Test { + operation Invoke(op : Qubit => Unit is Adj) : Result { + use q = Qubit(); + op(q); + MResetZ(q) + } + operation DoS(q : Qubit) : Unit is Adj { S(q); } + } + "#}; + let caps = adaptive_capabilities(); + let (store, pkg, items) = + compile_and_locate_items(source, &[("Invoke", true), ("DoS", true)], caps); + + let adjoint_do_s = Value::Closure(Box::new(qsc_eval::val::Closure { + fixed_args: Vec::::new().into(), + id: fir_id_for(pkg, items["DoS"]), + functor: FunctorApp { + adjoint: true, + controlled: 0, + }, + })); + + let target_hir = hir_id_for(pkg, items["Invoke"]); + let codegen_fir = + prepare_codegen_fir_from_callable_args(&store, target_hir, &adjoint_do_s, caps) + .unwrap_or_else(|errors| { + panic!( + "adjoint captureless closure should produce CodegenFir, got: {}", + format_interpret_errors(errors) + ) + }); + let entry = crate::codegen::qir::entry_from_codegen_fir(&codegen_fir); + let qir = qsc_codegen::qir::fir_to_qir( + &codegen_fir.fir_store, + caps, + &codegen_fir.compute_properties, + &entry, + ) + .unwrap_or_else(|e| panic!("synthetic entry QIR generation should succeed: {e:?}")); + assert!( + qir.contains("__quantum__qis__s__adj"), + "expected adjoint S gate in QIR:\n{qir}" + ); +} + +#[test] +fn synthetic_path_udt_wrapped_controlled_callable_preserves_functor() { + let source = indoc::indoc! 
{r#" + namespace Test { + newtype CtlBox = (Op: ((Qubit[], Qubit) => Unit), Tag: Int); + operation Invoke(b : CtlBox) : Result { + use (control, target) = (Qubit(), Qubit()); + b::Op([control], target); + Reset(control); + MResetZ(target) + } + operation DoX(q : Qubit) : Unit is Ctl { X(q); } + } + "#}; + let caps = adaptive_capabilities(); + let (store, pkg, items) = compile_and_locate_items( + source, + &[("Invoke", true), ("DoX", true), ("CtlBox", false)], + caps, + ); + + let controlled_do_x = Value::Global( + fir_id_for(pkg, items["DoX"]), + FunctorApp { + adjoint: false, + controlled: 1, + }, + ); + let boxed = Value::Tuple( + vec![controlled_do_x, Value::Int(0)].into(), + Some(Rc::new(fir_id_for(pkg, items["CtlBox"]))), + ); + + let target_hir = hir_id_for(pkg, items["Invoke"]); + let codegen_fir = prepare_codegen_fir_from_callable_args(&store, target_hir, &boxed, caps) + .unwrap_or_else(|errors| { + panic!( + "controlled UDT-wrapped callable should produce CodegenFir, got: {}", + format_interpret_errors(errors) + ) + }); + let entry = crate::codegen::qir::entry_from_codegen_fir(&codegen_fir); + let qir = qsc_codegen::qir::fir_to_qir( + &codegen_fir.fir_store, + caps, + &codegen_fir.compute_properties, + &entry, + ) + .unwrap_or_else(|e| panic!("synthetic entry QIR generation should succeed: {e:?}")); + assert!( + qir.contains("__quantum__qis__cx__body"), + "expected controlled X gate in QIR:\n{qir}" + ); +} + +// ---- Synthetic path: struct wrapping a callable field ---- + +#[test] +fn synthetic_path_single_field_struct_wrapping_callable_generates_qir() { + // Single-field UDT constructors are transparent in Value form: OpBox(DoH) + // is represented as the bare Global callable value. + let source = indoc::indoc! 
{r#" + namespace Test { + newtype OpBox = (Op: Qubit => Unit); + operation RunBoxed(b: OpBox) : Result { + use q = Qubit(); + b::Op(q); + MResetZ(q) + } + operation DoH(q: Qubit) : Unit { H(q); } + } + "#}; + let caps = adaptive_capabilities(); + let (store, pkg, items) = + compile_and_locate_items(source, &[("RunBoxed", true), ("DoH", true)], caps); + + let do_h = Value::Global(fir_id_for(pkg, items["DoH"]), FunctorApp::default()); + + let qir = callable_args_to_qir(&store, pkg, items["RunBoxed"], &do_h, caps); + assert!( + qir.contains("__quantum__qis__h__body"), + "expected H gate in QIR:\n{qir}" + ); +} + +#[test] +fn synthetic_path_struct_wrapping_callable_and_tag_generates_qir() { + // A newtype that wraps a callable and a non-callable field. + // This keeps tuple structure in the runtime Value while still exercising + // UDT pure-type discovery. + let source = indoc::indoc! {r#" + namespace Test { + newtype OpBox = (Op: Qubit => Unit, Tag: Int); + operation RunBoxed(b: OpBox) : Result { + use q = Qubit(); + b::Op(q); + MResetZ(q) + } + operation DoH(q: Qubit) : Unit { H(q); } + } + "#}; + let caps = adaptive_capabilities(); + let (store, pkg, items) = compile_and_locate_items( + source, + &[("RunBoxed", true), ("DoH", true), ("OpBox", false)], + caps, + ); + + let do_h = Value::Global(fir_id_for(pkg, items["DoH"]), FunctorApp::default()); + let boxed = Value::Tuple( + vec![do_h, Value::Int(0)].into(), + Some(Rc::new(fir_id_for(pkg, items["OpBox"]))), + ); + + let qir = callable_args_to_qir(&store, pkg, items["RunBoxed"], &boxed, caps); + assert!( + qir.contains("__quantum__qis__h__body"), + "expected H gate in QIR:\n{qir}" + ); +} + +#[test] +fn synthetic_path_udt_wrapped_adjoint_callable_preserves_functor() { + let source = indoc::indoc! 
{r#" + namespace Test { + newtype OpBox = (Op: Qubit => Unit is Adj, Tag: Int); + operation RunBoxed(b: OpBox) : Result { + use q = Qubit(); + b::Op(q); + MResetZ(q) + } + operation DoS(q: Qubit) : Unit is Adj { S(q); } + } + "#}; + let caps = adaptive_capabilities(); + let (store, pkg, items) = compile_and_locate_items( + source, + &[("RunBoxed", true), ("DoS", true), ("OpBox", false)], + caps, + ); + + let adjoint_do_s = Value::Global( + fir_id_for(pkg, items["DoS"]), + FunctorApp { + adjoint: true, + controlled: 0, + }, + ); + let boxed = Value::Tuple( + vec![adjoint_do_s, Value::Int(0)].into(), + Some(Rc::new(fir_id_for(pkg, items["OpBox"]))), + ); + + let target_hir = hir_id_for(pkg, items["RunBoxed"]); + let codegen_fir = prepare_codegen_fir_from_callable_args(&store, target_hir, &boxed, caps) + .unwrap_or_else(|errors| { + panic!( + "adjoint UDT-wrapped callable should produce CodegenFir, got: {}", + format_interpret_errors(errors) + ) + }); + let entry = crate::codegen::qir::entry_from_codegen_fir(&codegen_fir); + let qir = qsc_codegen::qir::fir_to_qir( + &codegen_fir.fir_store, + caps, + &codegen_fir.compute_properties, + &entry, + ) + .unwrap_or_else(|e| panic!("synthetic entry QIR generation should succeed: {e:?}")); + assert!( + qir.contains("__quantum__qis__s__adj"), + "expected adjoint S gate in QIR:\n{qir}" + ); +} + +// ---- Synthetic path: callable arg with additional non-callable tuple values ---- + +#[test] +fn synthetic_path_callable_with_double_and_string_generates_qir() { + // Target takes (factor: Double, op: Qubit => Unit, label: String). + // All three value types exercise different branches in `lower_value_to_expr`. + let source = indoc::indoc! 
{r#" + namespace Test { + operation Tagged(factor : Double, op : Qubit => Unit, label : String) : Result { + use q = Qubit(); + op(q); + MResetZ(q) + } + operation DoH(q : Qubit) : Unit { H(q); } + } + "#}; + let caps = adaptive_capabilities(); + let (store, pkg, items) = + compile_and_locate_items(source, &[("Tagged", true), ("DoH", true)], caps); + + let do_h = Value::Global(fir_id_for(pkg, items["DoH"]), FunctorApp::default()); + let args = Value::Tuple( + vec![Value::Double(1.5), do_h, Value::String("test".into())].into(), + None, + ); + + let qir = callable_args_to_qir(&store, pkg, items["Tagged"], &args, caps); + assert!( + qir.contains("__quantum__qis__h__body"), + "expected H gate in QIR:\n{qir}" + ); +} + +// ---- Synthetic path: nested struct (UDT inside UDT) with callable ---- + +#[test] +fn synthetic_path_nested_struct_with_callable_generates_qir() { + // Two levels of UDT wrapping: Config(Inner: OpBox, N: Int) where + // OpBox(Op: Qubit => Unit, Id: Int). This exercises UDT pure-type descent + // and nested field-chain replacement in defunctionalization. + // Inner UDTs need 2+ fields to avoid the single-field-UDT unwrap issue + // where the Value::Tuple shape misaligns with the erased type. + let source = indoc::indoc! 
{r#" + namespace Test { + newtype OpBox = (Op: Qubit => Unit, Id: Int); + newtype Config = (Inner: OpBox, N: Int); + operation RunConfig(cfg: Config) : Result { + use q = Qubit(); + cfg::Inner::Op(q); + MResetZ(q) + } + operation DoX(q: Qubit) : Unit { X(q); } + } + "#}; + let caps = adaptive_capabilities(); + let (store, pkg, items) = compile_and_locate_items( + source, + &[ + ("RunConfig", true), + ("DoX", true), + ("Config", false), + ("OpBox", false), + ], + caps, + ); + + let do_x = Value::Global(fir_id_for(pkg, items["DoX"]), FunctorApp::default()); + let inner = Value::Tuple( + vec![do_x, Value::Int(1)].into(), + Some(Rc::new(fir_id_for(pkg, items["OpBox"]))), + ); + let config = Value::Tuple( + vec![inner, Value::Int(5)].into(), + Some(Rc::new(fir_id_for(pkg, items["Config"]))), + ); + + let qir = callable_args_to_qir(&store, pkg, items["RunConfig"], &config, caps); + assert!( + qir.contains("__quantum__qis__x__body"), + "expected X gate in QIR:\n{qir}" + ); +} + +#[test] +fn synthetic_path_callable_field_taking_udt_with_callable_generates_qir() { + // Outer wraps a callable whose input is Inner, and Inner itself wraps a + // callable. This exercises UDT expansion through arrow input types, not + // just nested UDT fields that directly contain callable values. + let source = indoc::indoc! 
{r#" + namespace Test { + newtype Inner = (NestedOp: Qubit => Unit, Id: Int); + newtype Outer = (ApplyInner: Inner => Result, Id: Int); + + operation Invoke(outer: Outer) : Result { + let inner = Inner(DoH, 2); + outer::ApplyInner(inner) + } + + operation UseInner(inner: Inner) : Result { + use q = Qubit(); + inner::NestedOp(q); + MResetZ(q) + } + + operation DoH(q: Qubit) : Unit { H(q); } + } + "#}; + let caps = adaptive_capabilities(); + let (store, pkg, items) = compile_and_locate_items( + source, + &[ + ("Invoke", true), + ("UseInner", true), + ("DoH", true), + ("Inner", false), + ("Outer", false), + ], + caps, + ); + + let use_inner = Value::Global(fir_id_for(pkg, items["UseInner"]), FunctorApp::default()); + let outer = Value::Tuple( + vec![use_inner, Value::Int(1)].into(), + Some(Rc::new(fir_id_for(pkg, items["Outer"]))), + ); + + let target_hir = hir_id_for(pkg, items["Invoke"]); + let codegen_fir = prepare_codegen_fir_from_callable_args(&store, target_hir, &outer, caps) + .unwrap_or_else(|errors| { + panic!( + "callable field taking a UDT with a callable should produce CodegenFir, got: {}", + format_interpret_errors(errors) + ) + }); + let entry = crate::codegen::qir::entry_from_codegen_fir(&codegen_fir); + let qir = qsc_codegen::qir::fir_to_qir( + &codegen_fir.fir_store, + caps, + &codegen_fir.compute_properties, + &entry, + ) + .unwrap_or_else(|e| panic!("synthetic entry QIR generation should succeed: {e:?}")); + assert!( + qir.contains("__quantum__qis__h__body"), + "expected H gate in QIR:\n{qir}" + ); +} + +// ---- Synthetic path: tuple arg where only one element is callable ---- + +#[test] +fn synthetic_path_tuple_with_one_callable_among_many_scalars() { + // (Int, Int, Qubit => Unit, Bool, Int) — callable buried deep in a wide tuple. + let source = indoc::indoc! 
{r#" + namespace Test { + operation Wide(a : Int, b : Int, op : Qubit => Unit, flag : Bool, c : Int) : Result { + use q = Qubit(); + if flag { op(q); } + MResetZ(q) + } + operation DoH(q : Qubit) : Unit { H(q); } + } + "#}; + let caps = adaptive_capabilities(); + let (store, pkg, items) = + compile_and_locate_items(source, &[("Wide", true), ("DoH", true)], caps); + + let do_h = Value::Global(fir_id_for(pkg, items["DoH"]), FunctorApp::default()); + let args = Value::Tuple( + vec![ + Value::Int(1), + Value::Int(2), + do_h, + Value::Bool(true), + Value::Int(4), + ] + .into(), + None, + ); + + let qir = callable_args_to_qir(&store, pkg, items["Wide"], &args, caps); + assert!( + qir.contains("__quantum__qis__h__body"), + "expected H gate in QIR:\n{qir}" + ); +} + +// ---- Synthetic path: plain tuple with callable ---- + +#[test] +fn plain_tuple_with_callable_takes_synthetic_path() { + // A plain `Value::Tuple(_, None)` (no UDT tag) containing a callable takes + // the same synthetic path as UDT values. + let source = indoc::indoc! {r#" + namespace Test { + operation RunPair(op : Qubit => Unit, n : Int) : Result { + use q = Qubit(); + for _ in 0..n - 1 { op(q); } + MResetZ(q) + } + operation DoH(q : Qubit) : Unit { H(q); } + } + "#}; + let caps = adaptive_capabilities(); + let (store, pkg, items) = + compile_and_locate_items(source, &[("RunPair", true), ("DoH", true)], caps); + + let do_h = Value::Global(fir_id_for(pkg, items["DoH"]), FunctorApp::default()); + // Plain tuple — no UDT tag. + let args = Value::Tuple(vec![do_h, Value::Int(2)].into(), None); + + let qir = callable_args_to_qir(&store, pkg, items["RunPair"], &args, caps); + assert!( + qir.contains("__quantum__qis__h__body"), + "expected H gate in QIR:\n{qir}" + ); +} + +// ---- Synthetic path: struct with two callable fields ---- + +#[test] +fn synthetic_path_struct_with_two_callable_fields_generates_qir() { + // A newtype with two arrow fields. Both are wrapped in the UDT. + let source = indoc::indoc! 
{r#" + namespace Test { + newtype Ops = (First: Qubit => Unit, Second: Qubit => Unit); + operation RunOps(ops: Ops) : Result { + use q = Qubit(); + ops::First(q); + ops::Second(q); + MResetZ(q) + } + operation DoH(q: Qubit) : Unit { H(q); } + operation DoX(q: Qubit) : Unit { X(q); } + } + "#}; + let caps = adaptive_capabilities(); + let (store, pkg, items) = compile_and_locate_items( + source, + &[ + ("RunOps", true), + ("DoH", true), + ("DoX", true), + ("Ops", false), + ], + caps, + ); + + let do_h = Value::Global(fir_id_for(pkg, items["DoH"]), FunctorApp::default()); + let do_x = Value::Global(fir_id_for(pkg, items["DoX"]), FunctorApp::default()); + let ops = Value::Tuple( + vec![do_h, do_x].into(), + Some(Rc::new(fir_id_for(pkg, items["Ops"]))), + ); + + let qir = callable_args_to_qir(&store, pkg, items["RunOps"], &ops, caps); + assert!( + qir.contains("__quantum__qis__h__body"), + "expected H gate in QIR:\n{qir}" + ); + assert!( + qir.contains("__quantum__qis__x__body"), + "expected X gate in QIR:\n{qir}" + ); +} + +// ---- Synthetic path: callable with Pauli and Result args ---- + +#[test] +fn synthetic_path_callable_with_pauli_and_result_values() { + // Exercises the Pauli and Result branches of `lower_value_to_expr`. + let source = indoc::indoc! 
{r#" + namespace Test { + operation Measure(op : Qubit => Unit, basis : Pauli) : Result { + use q = Qubit(); + op(q); + MResetZ(q) + } + operation DoH(q : Qubit) : Unit { H(q); } + } + "#}; + let caps = adaptive_capabilities(); + let (store, pkg, items) = + compile_and_locate_items(source, &[("Measure", true), ("DoH", true)], caps); + + let do_h = Value::Global(fir_id_for(pkg, items["DoH"]), FunctorApp::default()); + let args = Value::Tuple( + vec![do_h, Value::Pauli(qsc_fir::fir::Pauli::Z)].into(), + None, + ); + + let qir = callable_args_to_qir(&store, pkg, items["Measure"], &args, caps); + assert!( + qir.contains("__quantum__qis__h__body"), + "expected H gate in QIR:\n{qir}" + ); +} + +mod base_profile { + use expect_test::expect; + use qsc_data_structures::target::TargetCapabilityFlags; + + use super::compile_source_to_qir; + static CAPABILITIES: std::sync::LazyLock = + std::sync::LazyLock::new(TargetCapabilityFlags::empty); + + #[test] + fn simple() { + let source = "namespace Test { + import Std.Math.*; + open QIR.Intrinsic; + @EntryPoint() + operation Main() : Result { + use q = Qubit(); + let pi_over_two = 4.0 / 2.0; + __quantum__qis__rz__body(pi_over_two, q); + mutable some_angle = ArcSin(0.0); + __quantum__qis__rz__body(some_angle, q); + set some_angle = ArcCos(-1.0) / PI(); + __quantum__qis__rz__body(some_angle, q); + __quantum__qis__mresetz__body(q) + } + }"; + + let qir = compile_source_to_qir(source, *CAPABILITIES); + expect![[r#" + %Result = type opaque + %Qubit = type opaque + + @0 = internal constant [4 x i8] c"0_r\00" + + define i64 @ENTRYPOINT__main() #0 { + block_0: + call void @__quantum__rt__initialize(i8* null) + call void @__quantum__qis__rz__body(double 2.0, %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__rz__body(double 0.0, %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__rz__body(double 1.0, %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 0 to %Qubit*), 
%Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + ret i64 0 + } + + declare void @__quantum__rt__initialize(i8*) + + declare void @__quantum__qis__rz__body(double, %Qubit*) + + declare void @__quantum__rt__result_record_output(%Result*, i8*) + + declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="1" "required_num_results"="1" } + attributes #1 = { "irreversible" } + + ; module flags + + !llvm.module.flags = !{!0, !1, !2, !3} + + !0 = !{i32 1, !"qir_major_version", i32 1} + !1 = !{i32 7, !"qir_minor_version", i32 0} + !2 = !{i32 1, !"dynamic_qubit_management", i1 false} + !3 = !{i32 1, !"dynamic_result_management", i1 false} + "#]] + .assert_eq(&qir); + } + + #[test] + fn qubit_reuse_triggers_reindexing() { + let source = "namespace Test { + @EntryPoint() + operation Main() : (Result, Result) { + use q = Qubit(); + (MResetZ(q), MResetZ(q)) + } + }"; + let qir = compile_source_to_qir(source, *CAPABILITIES); + expect![[r#" + %Result = type opaque + %Qubit = type opaque + + @0 = internal constant [4 x i8] c"0_t\00" + @1 = internal constant [6 x i8] c"1_t0r\00" + @2 = internal constant [6 x i8] c"2_t1r\00" + + define i64 @ENTRYPOINT__main() #0 { + block_0: + call void @__quantum__rt__initialize(i8* null) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) + call void @__quantum__rt__tuple_record_output(i64 2, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @1, i64 0, i64 0)) + call 
void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @2, i64 0, i64 0)) + ret i64 0 + } + + declare void @__quantum__rt__initialize(i8*) + + declare void @__quantum__rt__tuple_record_output(i64, i8*) + + declare void @__quantum__rt__result_record_output(%Result*, i8*) + + declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="2" "required_num_results"="2" } + attributes #1 = { "irreversible" } + + ; module flags + + !llvm.module.flags = !{!0, !1, !2, !3} + + !0 = !{i32 1, !"qir_major_version", i32 1} + !1 = !{i32 7, !"qir_minor_version", i32 0} + !2 = !{i32 1, !"dynamic_qubit_management", i1 false} + !3 = !{i32 1, !"dynamic_result_management", i1 false} + "#]].assert_eq(&qir); + } + + #[test] + fn qubit_measurements_get_deferred() { + let source = "namespace Test { + @EntryPoint() + operation Main() : Result[] { + use (q0, q1) = (Qubit(), Qubit()); + X(q0); + let r0 = MResetZ(q0); + X(q1); + let r1 = MResetZ(q1); + [r0, r1] + } + }"; + let qir = compile_source_to_qir(source, *CAPABILITIES); + expect![[r#" + %Result = type opaque + %Qubit = type opaque + + @0 = internal constant [4 x i8] c"0_a\00" + @1 = internal constant [6 x i8] c"1_a0r\00" + @2 = internal constant [6 x i8] c"2_a1r\00" + + define i64 @ENTRYPOINT__main() #0 { + block_0: + call void @__quantum__rt__initialize(i8* null) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) + call void @__quantum__rt__array_record_output(i64 2, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + 
call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @1, i64 0, i64 0)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @2, i64 0, i64 0)) + ret i64 0 + } + + declare void @__quantum__rt__initialize(i8*) + + declare void @__quantum__qis__x__body(%Qubit*) + + declare void @__quantum__rt__array_record_output(i64, i8*) + + declare void @__quantum__rt__result_record_output(%Result*, i8*) + + declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="2" "required_num_results"="2" } + attributes #1 = { "irreversible" } + + ; module flags + + !llvm.module.flags = !{!0, !1, !2, !3} + + !0 = !{i32 1, !"qir_major_version", i32 1} + !1 = !{i32 7, !"qir_minor_version", i32 0} + !2 = !{i32 1, !"dynamic_qubit_management", i1 false} + !3 = !{i32 1, !"dynamic_result_management", i1 false} + "#]].assert_eq(&qir); + } + + #[test] + fn qubit_id_swap_results_in_different_id_usage() { + let source = "namespace Test { + @EntryPoint() + operation Main() : (Result, Result) { + use (q0, q1) = (Qubit(), Qubit()); + X(q0); + Relabel([q0, q1], [q1, q0]); + X(q1); + (MResetZ(q0), MResetZ(q1)) + } + }"; + let qir = compile_source_to_qir(source, *CAPABILITIES); + expect![[r#" + %Result = type opaque + %Qubit = type opaque + + @0 = internal constant [4 x i8] c"0_t\00" + @1 = internal constant [6 x i8] c"1_t0r\00" + @2 = internal constant [6 x i8] c"2_t1r\00" + + define i64 @ENTRYPOINT__main() #0 { + block_0: + call void @__quantum__rt__initialize(i8* null) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + 
call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) + call void @__quantum__rt__tuple_record_output(i64 2, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @1, i64 0, i64 0)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @2, i64 0, i64 0)) + ret i64 0 + } + + declare void @__quantum__rt__initialize(i8*) + + declare void @__quantum__qis__x__body(%Qubit*) + + declare void @__quantum__rt__tuple_record_output(i64, i8*) + + declare void @__quantum__rt__result_record_output(%Result*, i8*) + + declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="2" "required_num_results"="2" } + attributes #1 = { "irreversible" } + + ; module flags + + !llvm.module.flags = !{!0, !1, !2, !3} + + !0 = !{i32 1, !"qir_major_version", i32 1} + !1 = !{i32 7, !"qir_minor_version", i32 0} + !2 = !{i32 1, !"dynamic_qubit_management", i1 false} + !3 = !{i32 1, !"dynamic_result_management", i1 false} + "#]].assert_eq(&qir); + } + + #[test] + fn qubit_id_swap_across_reset_uses_updated_ids() { + let source = "namespace Test { + @EntryPoint() + operation Main() : (Result, Result) { + { + use (q0, q1) = (Qubit(), Qubit()); + X(q0); + Relabel([q0, q1], [q1, q0]); + X(q1); + Reset(q0); + Reset(q1); + } + use (q0, q1) = (Qubit(), Qubit()); + (MResetZ(q0), MResetZ(q1)) + } + }"; + let qir = compile_source_to_qir(source, *CAPABILITIES); + expect![[r#" + %Result = type opaque + %Qubit = type opaque + + @0 = internal constant [4 x i8] c"0_t\00" + @1 = internal constant [6 x i8] c"1_t0r\00" + @2 = internal constant [6 x i8] c"2_t1r\00" + + define i64 @ENTRYPOINT__main() #0 { + block_0: + call 
void @__quantum__rt__initialize(i8* null) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) + call void @__quantum__rt__tuple_record_output(i64 2, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @1, i64 0, i64 0)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @2, i64 0, i64 0)) + ret i64 0 + } + + declare void @__quantum__rt__initialize(i8*) + + declare void @__quantum__qis__x__body(%Qubit*) + + declare void @__quantum__rt__tuple_record_output(i64, i8*) + + declare void @__quantum__rt__result_record_output(%Result*, i8*) + + declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="3" "required_num_results"="2" } + attributes #1 = { "irreversible" } + + ; module flags + + !llvm.module.flags = !{!0, !1, !2, !3} + + !0 = !{i32 1, !"qir_major_version", i32 1} + !1 = !{i32 7, !"qir_minor_version", i32 0} + !2 = !{i32 1, !"dynamic_qubit_management", i1 false} + !3 = !{i32 1, !"dynamic_result_management", i1 false} + "#]].assert_eq(&qir); + } + + #[test] + fn noise_intrinsic_generates_correct_qir() { + let source = "namespace Test { + operation Main() : Result { + use q = Qubit(); + test_noise_intrinsic(q); + MResetZ(q) + } + + @NoiseIntrinsic() + operation test_noise_intrinsic(target: Qubit) : Unit { + body intrinsic; + } + }"; + + let qir = compile_source_to_qir(source, *CAPABILITIES); + 
expect![[r#" + %Result = type opaque + %Qubit = type opaque + + @0 = internal constant [4 x i8] c"0_r\00" + + define i64 @ENTRYPOINT__main() #0 { + block_0: + call void @__quantum__rt__initialize(i8* null) + call void @test_noise_intrinsic(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + ret i64 0 + } + + declare void @__quantum__rt__initialize(i8*) + + declare void @test_noise_intrinsic(%Qubit*) #2 + + declare void @__quantum__rt__result_record_output(%Result*, i8*) + + declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="1" "required_num_results"="1" } + attributes #1 = { "irreversible" } + attributes #2 = { "qdk_noise" } + + ; module flags + + !llvm.module.flags = !{!0, !1, !2, !3} + + !0 = !{i32 1, !"qir_major_version", i32 1} + !1 = !{i32 7, !"qir_minor_version", i32 0} + !2 = !{i32 1, !"dynamic_qubit_management", i1 false} + !3 = !{i32 1, !"dynamic_result_management", i1 false} + "#]].assert_eq(&qir); + } +} + +mod adaptive_profile { + use super::compile_source_to_qir; + use expect_test::expect; + use qsc_data_structures::target::TargetCapabilityFlags; + static CAPABILITIES: std::sync::LazyLock = + std::sync::LazyLock::new(|| TargetCapabilityFlags::Adaptive); + + #[test] + fn simple() { + let source = "namespace Test { + import Std.Math.*; + open QIR.Intrinsic; + @EntryPoint() + operation Main() : Result { + use q = Qubit(); + let pi_over_two = 4.0 / 2.0; + __quantum__qis__rz__body(pi_over_two, q); + mutable some_angle = ArcSin(0.0); + __quantum__qis__rz__body(some_angle, q); + set some_angle = ArcCos(-1.0) / PI(); + __quantum__qis__rz__body(some_angle, q); + 
__quantum__qis__mresetz__body(q) + } + }"; + let qir = compile_source_to_qir(source, *CAPABILITIES); + expect![[r#" + %Result = type opaque + %Qubit = type opaque + + @0 = internal constant [4 x i8] c"0_r\00" + + define i64 @ENTRYPOINT__main() #0 { + block_0: + call void @__quantum__rt__initialize(i8* null) + call void @__quantum__qis__rz__body(double 2.0, %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__rz__body(double 0.0, %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__rz__body(double 1.0, %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + ret i64 0 + } + + declare void @__quantum__rt__initialize(i8*) + + declare void @__quantum__qis__rz__body(double, %Qubit*) + + declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 + + declare void @__quantum__rt__result_record_output(%Result*, i8*) + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="1" } + attributes #1 = { "irreversible" } + + ; module flags + + !llvm.module.flags = !{!0, !1, !2, !3} + + !0 = !{i32 1, !"qir_major_version", i32 1} + !1 = !{i32 7, !"qir_minor_version", i32 0} + !2 = !{i32 1, !"dynamic_qubit_management", i1 false} + !3 = !{i32 1, !"dynamic_result_management", i1 false} + "#]] + .assert_eq(&qir); + } + + #[test] + fn noise_intrinsic_generates_correct_qir() { + let source = "namespace Test { + operation Main() : Result { + use q = Qubit(); + test_noise_intrinsic(q); + MResetZ(q) + } + + @NoiseIntrinsic() + operation test_noise_intrinsic(target: Qubit) : Unit { + body intrinsic; + } + }"; + + let qir = compile_source_to_qir(source, *CAPABILITIES); + expect![[r#" + %Result = type opaque + %Qubit = type 
opaque + + @0 = internal constant [4 x i8] c"0_r\00" + + define i64 @ENTRYPOINT__main() #0 { + block_0: + call void @__quantum__rt__initialize(i8* null) + call void @test_noise_intrinsic(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + ret i64 0 + } + + declare void @__quantum__rt__initialize(i8*) + + declare void @test_noise_intrinsic(%Qubit*) #2 + + declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 + + declare void @__quantum__rt__result_record_output(%Result*, i8*) + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="1" } + attributes #1 = { "irreversible" } + attributes #2 = { "qdk_noise" } + + ; module flags + + !llvm.module.flags = !{!0, !1, !2, !3} + + !0 = !{i32 1, !"qir_major_version", i32 1} + !1 = !{i32 7, !"qir_minor_version", i32 0} + !2 = !{i32 1, !"dynamic_qubit_management", i1 false} + !3 = !{i32 1, !"dynamic_result_management", i1 false} + "#]].assert_eq(&qir); + } + + #[test] + fn custom_measurement_generates_correct_qir() { + let source = "namespace Test { + operation Main() : Result { + use q = Qubit(); + H(q); + __quantum__qis__mx__body(q) } - }"; + @Measurement() + operation __quantum__qis__mx__body(target: Qubit) : Result { + body intrinsic; + } + }"; let qir = compile_source_to_qir(source, *CAPABILITIES); expect![[r#" %Result = type opaque @@ -281,23 +2935,21 @@ mod base_profile { define i64 @ENTRYPOINT__main() #0 { block_0: call void @__quantum__rt__initialize(i8* null) - call void @__quantum__qis__rz__body(double 2.0, %Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__rz__body(double 0.0, %Qubit* inttoptr (i64 0 to %Qubit*)) - call void 
@__quantum__qis__rz__body(double 1.0, %Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__mx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) ret i64 0 } declare void @__quantum__rt__initialize(i8*) - declare void @__quantum__qis__rz__body(double, %Qubit*) + declare void @__quantum__qis__h__body(%Qubit*) - declare void @__quantum__rt__result_record_output(%Result*, i8*) + declare void @__quantum__qis__mx__body(%Qubit*, %Result*) #1 - declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 + declare void @__quantum__rt__result_record_output(%Result*, i8*) - attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="1" "required_num_results"="1" } + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="1" } attributes #1 = { "irreversible" } ; module flags @@ -308,17 +2960,23 @@ mod base_profile { !1 = !{i32 7, !"qir_minor_version", i32 0} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} - "#]] - .assert_eq(&qir); + "#]].assert_eq(&qir); } #[test] - fn qubit_reuse_triggers_reindexing() { + fn custom_joint_measurement_generates_correct_qir() { let source = "namespace Test { - @EntryPoint() operation Main() : (Result, Result) { - use q = Qubit(); - (MResetZ(q), MResetZ(q)) + use q1 = Qubit(); + use q2 = Qubit(); + H(q1); + H(q2); + __quantum__qis__mzz__body(q1, q2) + } + + @Measurement() + operation __quantum__qis__mzz__body(q1: Qubit, q2: Qubit) : (Result, Result) { + body intrinsic; } }"; 
let qir = compile_source_to_qir(source, *CAPABILITIES); @@ -333,8 +2991,9 @@ mod base_profile { define i64 @ENTRYPOINT__main() #0 { block_0: call void @__quantum__rt__initialize(i8* null) - call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) - call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__mzz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 0 to %Result*), %Result* inttoptr (i64 1 to %Result*)) call void @__quantum__rt__tuple_record_output(i64 2, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @1, i64 0, i64 0)) call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @2, i64 0, i64 0)) @@ -343,13 +3002,15 @@ mod base_profile { declare void @__quantum__rt__initialize(i8*) + declare void @__quantum__qis__h__body(%Qubit*) + + declare void @__quantum__qis__mzz__body(%Qubit*, %Qubit*, %Result*, %Result*) #1 + declare void @__quantum__rt__tuple_record_output(i64, i8*) declare void @__quantum__rt__result_record_output(%Result*, i8*) - declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 - - attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="2" "required_num_results"="2" } + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="2" "required_num_results"="2" } attributes #1 = { "irreversible" } ; module flags @@ -364,7 +3025,7 @@ mod base_profile { } #[test] - fn 
qubit_measurements_get_deferred() { + fn qubit_measurements_not_deferred() { let source = "namespace Test { @EntryPoint() operation Main() : Result[] { @@ -389,9 +3050,9 @@ mod base_profile { block_0: call void @__quantum__rt__initialize(i8* null) call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 1 to %Qubit*)) - call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) - call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) call void @__quantum__rt__array_record_output(i64 2, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @1, i64 0, i64 0)) call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @2, i64 0, i64 0)) @@ -402,13 +3063,13 @@ mod base_profile { declare void @__quantum__qis__x__body(%Qubit*) + declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 + declare void @__quantum__rt__array_record_output(i64, i8*) declare void @__quantum__rt__result_record_output(%Result*, i8*) - declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 - - attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="2" "required_num_results"="2" } + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="2" "required_num_results"="2" } attributes #1 = { "irreversible" } ; module flags @@ -421,17 +3082,126 @@ mod 
base_profile { !3 = !{i32 1, !"dynamic_result_management", i1 false} "#]].assert_eq(&qir); } +} + +mod adaptive_ri_profile { + + use expect_test::expect; + use qsc_data_structures::target::TargetCapabilityFlags; + + use super::{compile_source_to_qir, compile_source_to_qir_from_ast, compile_source_to_rir}; + static CAPABILITIES: std::sync::LazyLock = + std::sync::LazyLock::new(|| { + TargetCapabilityFlags::Adaptive | TargetCapabilityFlags::IntegerComputations + }); + + fn terminal_result_return_with_qubit_cleanup_source() -> &'static str { + indoc::indoc! {r#" + namespace Test { + @EntryPoint() + operation Main() : Result { + use q = Qubit(); + let r = M(q); + Reset(q); + return r; + } + } + "#} + } + + fn assert_terminal_result_return_with_qubit_cleanup_qir(qir: &str) { + expect![[r#" + %Result = type opaque + %Qubit = type opaque + + @0 = internal constant [4 x i8] c"0_r\00" + + define i64 @ENTRYPOINT__main() #0 { + block_0: + call void @__quantum__rt__initialize(i8* null) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__qis__reset__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + ret i64 0 + } + + declare void @__quantum__rt__initialize(i8*) + + declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 + + declare void @__quantum__qis__reset__body(%Qubit*) #1 + + declare void @__quantum__rt__result_record_output(%Result*, i8*) + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="1" } + attributes #1 = { "irreversible" } + + ; module flags + + !llvm.module.flags = !{!0, !1, !2, !3, !4} + + !0 = !{i32 1, !"qir_major_version", i32 1} + !1 = !{i32 7, !"qir_minor_version", i32 0} + !2 = !{i32 1, !"dynamic_qubit_management", i1 false} + !3 = 
!{i32 1, !"dynamic_result_management", i1 false} + !4 = !{i32 5, !"int_computations", !{!"i64"}} + "#]] + .assert_eq(qir); + } + + fn assert_terminal_result_return_with_qubit_cleanup_rir(program: &str, form: &str) { + assert!( + program.contains("name: __quantum__qis__m__body"), + "{form} RIR should include the measurement callable" + ); + assert!( + program.contains("name: __quantum__qis__reset__body"), + "{form} RIR should include the cleanup reset callable" + ); + assert!( + program.contains("name: __quantum__rt__result_record_output"), + "{form} RIR should include result output recording" + ); + assert!( + program.contains("num_qubits: 1"), + "{form} RIR should keep a single allocated qubit" + ); + assert!( + program.contains("num_results: 1"), + "{form} RIR should keep a single returned result" + ); + + let measurement_call = program + .find("args( Qubit(0), Result(0), )") + .unwrap_or_else(|| panic!("{form} RIR should contain the measurement call")); + let reset_call = program + .find("args( Qubit(0), )") + .unwrap_or_else(|| panic!("{form} RIR should contain the cleanup reset call")); + let output_call = program + .find("args( Result(0), Tag(") + .unwrap_or_else(|| panic!("{form} RIR should record the returned result")); + + assert!( + measurement_call < reset_call && reset_call < output_call, + "{form} RIR should measure, reset, and then record the returned result" + ); + } #[test] - fn qubit_id_swap_results_in_different_id_usage() { + fn simple() { let source = "namespace Test { + import Std.Math.*; + open QIR.Intrinsic; @EntryPoint() - operation Main() : (Result, Result) { - use (q0, q1) = (Qubit(), Qubit()); - X(q0); - Relabel([q0, q1], [q1, q0]); - X(q1); - (MResetZ(q0), MResetZ(q1)) + operation Main() : Result { + use q = Qubit(); + let pi_over_two = 4.0 / 2.0; + __quantum__qis__rz__body(pi_over_two, q); + mutable some_angle = ArcSin(0.0); + __quantum__qis__rz__body(some_angle, q); + set some_angle = ArcCos(-1.0) / PI(); + 
__quantum__qis__rz__body(some_angle, q); + __quantum__qis__mresetz__body(q) } }"; let qir = compile_source_to_qir(source, *CAPABILITIES); @@ -439,62 +3209,50 @@ mod base_profile { %Result = type opaque %Qubit = type opaque - @0 = internal constant [4 x i8] c"0_t\00" - @1 = internal constant [6 x i8] c"1_t0r\00" - @2 = internal constant [6 x i8] c"2_t1r\00" + @0 = internal constant [4 x i8] c"0_r\00" define i64 @ENTRYPOINT__main() #0 { block_0: call void @__quantum__rt__initialize(i8* null) - call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) - call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) - call void @__quantum__rt__tuple_record_output(i64 2, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) - call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @1, i64 0, i64 0)) - call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @2, i64 0, i64 0)) + call void @__quantum__qis__rz__body(double 2.0, %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__rz__body(double 0.0, %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__rz__body(double 1.0, %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) ret i64 0 } declare void @__quantum__rt__initialize(i8*) - declare void @__quantum__qis__x__body(%Qubit*) + declare void @__quantum__qis__rz__body(double, %Qubit*) 
- declare void @__quantum__rt__tuple_record_output(i64, i8*) + declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 declare void @__quantum__rt__result_record_output(%Result*, i8*) - declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 - - attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="2" "required_num_results"="2" } + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="1" } attributes #1 = { "irreversible" } ; module flags - !llvm.module.flags = !{!0, !1, !2, !3} + !llvm.module.flags = !{!0, !1, !2, !3, !4} !0 = !{i32 1, !"qir_major_version", i32 1} !1 = !{i32 7, !"qir_minor_version", i32 0} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} - "#]].assert_eq(&qir); + !4 = !{i32 5, !"int_computations", !{!"i64"}} + "#]] + .assert_eq(&qir); } #[test] - fn qubit_id_swap_across_reset_uses_updated_ids() { + fn qubit_reuse_allowed() { let source = "namespace Test { @EntryPoint() operation Main() : (Result, Result) { - { - use (q0, q1) = (Qubit(), Qubit()); - X(q0); - Relabel([q0, q1], [q1, q0]); - X(q1); - Reset(q0); - Reset(q1); - } - use (q0, q1) = (Qubit(), Qubit()); - (MResetZ(q0), MResetZ(q1)) + use q = Qubit(); + (MResetZ(q), MResetZ(q)) } }"; let qir = compile_source_to_qir(source, *CAPABILITIES); @@ -509,10 +3267,8 @@ mod base_profile { define i64 @ENTRYPOINT__main() #0 { block_0: call void @__quantum__rt__initialize(i8* null) - call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) - call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) + call void 
@__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) call void @__quantum__rt__tuple_record_output(i64 2, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @1, i64 0, i64 0)) call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @2, i64 0, i64 0)) @@ -521,105 +3277,97 @@ mod base_profile { declare void @__quantum__rt__initialize(i8*) - declare void @__quantum__qis__x__body(%Qubit*) + declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 declare void @__quantum__rt__tuple_record_output(i64, i8*) declare void @__quantum__rt__result_record_output(%Result*, i8*) - declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 - - attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="3" "required_num_results"="2" } + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="2" } attributes #1 = { "irreversible" } ; module flags - !llvm.module.flags = !{!0, !1, !2, !3} + !llvm.module.flags = !{!0, !1, !2, !3, !4} !0 = !{i32 1, !"qir_major_version", i32 1} !1 = !{i32 7, !"qir_minor_version", i32 0} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} + !4 = !{i32 5, !"int_computations", !{!"i64"}} "#]].assert_eq(&qir); } #[test] - fn noise_intrinsic_generates_correct_qir() { + fn qubit_measurements_not_deferred() { let source = "namespace Test { - operation Main() : Result { - use q = Qubit(); - test_noise_intrinsic(q); - MResetZ(q) - } - - @NoiseIntrinsic() - operation 
test_noise_intrinsic(target: Qubit) : Unit { - body intrinsic; + @EntryPoint() + operation Main() : Result[] { + use (q0, q1) = (Qubit(), Qubit()); + X(q0); + let r0 = MResetZ(q0); + X(q1); + let r1 = MResetZ(q1); + [r0, r1] } }"; - let qir = compile_source_to_qir(source, *CAPABILITIES); expect![[r#" %Result = type opaque %Qubit = type opaque - @0 = internal constant [4 x i8] c"0_r\00" + @0 = internal constant [4 x i8] c"0_a\00" + @1 = internal constant [6 x i8] c"1_a0r\00" + @2 = internal constant [6 x i8] c"2_a1r\00" define i64 @ENTRYPOINT__main() #0 { block_0: call void @__quantum__rt__initialize(i8* null) - call void @test_noise_intrinsic(%Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) - call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) + call void @__quantum__rt__array_record_output(i64 2, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @1, i64 0, i64 0)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @2, i64 0, i64 0)) ret i64 0 } declare void @__quantum__rt__initialize(i8*) - declare void @test_noise_intrinsic(%Qubit*) #2 + declare void @__quantum__qis__x__body(%Qubit*) - declare void @__quantum__rt__result_record_output(%Result*, i8*) + 
declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 - declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 + declare void @__quantum__rt__array_record_output(i64, i8*) - attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="1" "required_num_results"="1" } + declare void @__quantum__rt__result_record_output(%Result*, i8*) + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="2" "required_num_results"="2" } attributes #1 = { "irreversible" } - attributes #2 = { "qdk_noise" } ; module flags - !llvm.module.flags = !{!0, !1, !2, !3} + !llvm.module.flags = !{!0, !1, !2, !3, !4} !0 = !{i32 1, !"qir_major_version", i32 1} !1 = !{i32 7, !"qir_minor_version", i32 0} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} - "#]].assert_eq(&qir); - } -} - -mod adaptive_profile { - use super::compile_source_to_qir; - use expect_test::expect; - use qsc_data_structures::target::TargetCapabilityFlags; - static CAPABILITIES: std::sync::LazyLock = - std::sync::LazyLock::new(|| TargetCapabilityFlags::Adaptive); + !4 = !{i32 5, !"int_computations", !{!"i64"}} + "#]].assert_eq(&qir); + } #[test] - fn simple() { + fn qubit_id_swap_results_in_different_id_usage() { let source = "namespace Test { - import Std.Math.*; - open QIR.Intrinsic; @EntryPoint() - operation Main() : Result { - use q = Qubit(); - let pi_over_two = 4.0 / 2.0; - __quantum__qis__rz__body(pi_over_two, q); - mutable some_angle = ArcSin(0.0); - __quantum__qis__rz__body(some_angle, q); - set some_angle = ArcCos(-1.0) / PI(); - __quantum__qis__rz__body(some_angle, q); - __quantum__qis__mresetz__body(q) + operation Main() : (Result, Result) { + use (q0, q1) = (Qubit(), Qubit()); + X(q0); + Relabel([q0, q1], [q1, q0]); + X(q1); + (MResetZ(q0), MResetZ(q1)) } }"; let qir = compile_source_to_qir(source, *CAPABILITIES); @@ -627,108 
+3375,132 @@ mod adaptive_profile { %Result = type opaque %Qubit = type opaque - @0 = internal constant [4 x i8] c"0_r\00" + @0 = internal constant [4 x i8] c"0_t\00" + @1 = internal constant [6 x i8] c"1_t0r\00" + @2 = internal constant [6 x i8] c"2_t1r\00" define i64 @ENTRYPOINT__main() #0 { block_0: call void @__quantum__rt__initialize(i8* null) - call void @__quantum__qis__rz__body(double 2.0, %Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__rz__body(double 0.0, %Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__rz__body(double 1.0, %Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) - call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) + call void @__quantum__rt__tuple_record_output(i64 2, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @1, i64 0, i64 0)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @2, i64 0, i64 0)) ret i64 0 } declare void @__quantum__rt__initialize(i8*) - declare void @__quantum__qis__rz__body(double, %Qubit*) + declare void @__quantum__qis__x__body(%Qubit*) declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 + declare void @__quantum__rt__tuple_record_output(i64, i8*) + 
declare void @__quantum__rt__result_record_output(%Result*, i8*) - attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="1" } + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="2" "required_num_results"="2" } attributes #1 = { "irreversible" } ; module flags - !llvm.module.flags = !{!0, !1, !2, !3} + !llvm.module.flags = !{!0, !1, !2, !3, !4} !0 = !{i32 1, !"qir_major_version", i32 1} !1 = !{i32 7, !"qir_minor_version", i32 0} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} - "#]] - .assert_eq(&qir); + !4 = !{i32 5, !"int_computations", !{!"i64"}} + "#]].assert_eq(&qir); } #[test] - fn noise_intrinsic_generates_correct_qir() { + fn qubit_id_swap_across_reset_uses_updated_ids() { let source = "namespace Test { - operation Main() : Result { - use q = Qubit(); - test_noise_intrinsic(q); - MResetZ(q) - } - - @NoiseIntrinsic() - operation test_noise_intrinsic(target: Qubit) : Unit { - body intrinsic; + @EntryPoint() + operation Main() : (Result, Result) { + { + use (q0, q1) = (Qubit(), Qubit()); + X(q0); + Relabel([q0, q1], [q1, q0]); + X(q1); + Reset(q0); + Reset(q1); + } + use (q0, q1) = (Qubit(), Qubit()); + (MResetZ(q0), MResetZ(q1)) } }"; - let qir = compile_source_to_qir(source, *CAPABILITIES); expect![[r#" %Result = type opaque %Qubit = type opaque - @0 = internal constant [4 x i8] c"0_r\00" + @0 = internal constant [4 x i8] c"0_t\00" + @1 = internal constant [6 x i8] c"1_t0r\00" + @2 = internal constant [6 x i8] c"2_t1r\00" define i64 @ENTRYPOINT__main() #0 { block_0: call void @__quantum__rt__initialize(i8* null) - call void @test_noise_intrinsic(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void 
@__quantum__qis__reset__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__reset__body(%Qubit* inttoptr (i64 0 to %Qubit*)) call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) - call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) + call void @__quantum__rt__tuple_record_output(i64 2, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @1, i64 0, i64 0)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @2, i64 0, i64 0)) ret i64 0 } declare void @__quantum__rt__initialize(i8*) - declare void @test_noise_intrinsic(%Qubit*) #2 + declare void @__quantum__qis__x__body(%Qubit*) + + declare void @__quantum__qis__reset__body(%Qubit*) #1 declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 + declare void @__quantum__rt__tuple_record_output(i64, i8*) + declare void @__quantum__rt__result_record_output(%Result*, i8*) - attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="1" } + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="2" "required_num_results"="2" } attributes #1 = { "irreversible" } - attributes #2 = { "qdk_noise" } ; module flags - !llvm.module.flags = !{!0, !1, !2, !3} + !llvm.module.flags = !{!0, !1, !2, !3, !4} !0 = !{i32 1, !"qir_major_version", i32 1} !1 = !{i32 7, !"qir_minor_version", i32 0} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, 
!"dynamic_result_management", i1 false} + !4 = !{i32 5, !"int_computations", !{!"i64"}} "#]].assert_eq(&qir); } #[test] - fn custom_measurement_generates_correct_qir() { + fn qubit_id_swap_with_out_of_order_release_uses_correct_ids() { let source = "namespace Test { - operation Main() : Result { - use q = Qubit(); - H(q); - __quantum__qis__mx__body(q) - } - - @Measurement() - operation __quantum__qis__mx__body(target: Qubit) : Result { - body intrinsic; + @EntryPoint() + operation Main() : (Result, Result) { + let q0 = QIR.Runtime.__quantum__rt__qubit_allocate(); + let q1 = QIR.Runtime.__quantum__rt__qubit_allocate(); + let q2 = QIR.Runtime.__quantum__rt__qubit_allocate(); + X(q0); + X(q1); + X(q2); + Relabel([q0, q1], [q1, q0]); + QIR.Runtime.__quantum__rt__qubit_release(q0); + let q3 = QIR.Runtime.__quantum__rt__qubit_allocate(); + X(q3); + (MResetZ(q3), MResetZ(q1)) } }"; let qir = compile_source_to_qir(source, *CAPABILITIES); @@ -736,53 +3508,58 @@ mod adaptive_profile { %Result = type opaque %Qubit = type opaque - @0 = internal constant [4 x i8] c"0_r\00" + @0 = internal constant [4 x i8] c"0_t\00" + @1 = internal constant [6 x i8] c"1_t0r\00" + @2 = internal constant [6 x i8] c"2_t1r\00" define i64 @ENTRYPOINT__main() #0 { block_0: call void @__quantum__rt__initialize(i8* null) - call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__mx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) - call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 2 to %Qubit*)) + call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__mresetz__body(%Qubit* 
inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) + call void @__quantum__rt__tuple_record_output(i64 2, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @1, i64 0, i64 0)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @2, i64 0, i64 0)) ret i64 0 } declare void @__quantum__rt__initialize(i8*) - declare void @__quantum__qis__h__body(%Qubit*) + declare void @__quantum__qis__x__body(%Qubit*) - declare void @__quantum__qis__mx__body(%Qubit*, %Result*) #1 + declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 + + declare void @__quantum__rt__tuple_record_output(i64, i8*) declare void @__quantum__rt__result_record_output(%Result*, i8*) - attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="1" } + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="3" "required_num_results"="2" } attributes #1 = { "irreversible" } ; module flags - !llvm.module.flags = !{!0, !1, !2, !3} + !llvm.module.flags = !{!0, !1, !2, !3, !4} !0 = !{i32 1, !"qir_major_version", i32 1} !1 = !{i32 7, !"qir_minor_version", i32 0} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} + !4 = !{i32 5, !"int_computations", !{!"i64"}} "#]].assert_eq(&qir); } #[test] - fn custom_joint_measurement_generates_correct_qir() { + fn dynamic_integer_with_branch_and_phi_supported() { let source = "namespace Test { - operation Main() : (Result, Result) { - use q1 = Qubit(); - use q2 = Qubit(); - H(q1); - H(q2); - 
__quantum__qis__mzz__body(q1, q2) - } - - @Measurement() - operation __quantum__qis__mzz__body(q1: Qubit, q2: Qubit) : (Result, Result) { - body intrinsic; + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + H(q); + MResetZ(q) == Zero ? 0 | 1 } }"; let qir = compile_source_to_qir(source, *CAPABILITIES); @@ -790,19 +3567,23 @@ mod adaptive_profile { %Result = type opaque %Qubit = type opaque - @0 = internal constant [4 x i8] c"0_t\00" - @1 = internal constant [6 x i8] c"1_t0r\00" - @2 = internal constant [6 x i8] c"2_t1r\00" + @0 = internal constant [4 x i8] c"0_i\00" define i64 @ENTRYPOINT__main() #0 { block_0: call void @__quantum__rt__initialize(i8* null) call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*)) - call void @__quantum__qis__mzz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 0 to %Result*), %Result* inttoptr (i64 1 to %Result*)) - call void @__quantum__rt__tuple_record_output(i64 2, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) - call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @1, i64 0, i64 0)) - call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @2, i64 0, i64 0)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + %var_0 = call i1 @__quantum__rt__read_result(%Result* inttoptr (i64 0 to %Result*)) + %var_1 = icmp eq i1 %var_0, false + br i1 %var_1, label %block_1, label %block_2 + block_1: + br label %block_3 + block_2: + br label %block_3 + block_3: + %var_4 = phi i64 [0, %block_1], [1, %block_2] + call void @__quantum__rt__int_record_output(i64 %var_4, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) ret i64 0 } @@ 
-810,37 +3591,39 @@ mod adaptive_profile { declare void @__quantum__qis__h__body(%Qubit*) - declare void @__quantum__qis__mzz__body(%Qubit*, %Qubit*, %Result*, %Result*) #1 + declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 - declare void @__quantum__rt__tuple_record_output(i64, i8*) + declare i1 @__quantum__rt__read_result(%Result*) - declare void @__quantum__rt__result_record_output(%Result*, i8*) + declare void @__quantum__rt__int_record_output(i64, i8*) - attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="2" "required_num_results"="2" } + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="1" } attributes #1 = { "irreversible" } ; module flags - !llvm.module.flags = !{!0, !1, !2, !3} + !llvm.module.flags = !{!0, !1, !2, !3, !4} !0 = !{i32 1, !"qir_major_version", i32 1} !1 = !{i32 7, !"qir_minor_version", i32 0} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} + !4 = !{i32 5, !"int_computations", !{!"i64"}} "#]].assert_eq(&qir); } #[test] - fn qubit_measurements_not_deferred() { + fn custom_reset_generates_correct_qir() { let source = "namespace Test { - @EntryPoint() - operation Main() : Result[] { - use (q0, q1) = (Qubit(), Qubit()); - X(q0); - let r0 = MResetZ(q0); - X(q1); - let r1 = MResetZ(q1); - [r0, r1] + operation Main() : Result { + use q = Qubit(); + __quantum__qis__custom_reset__body(q); + M(q) + } + + @Reset() + operation __quantum__qis__custom_reset__body(target: Qubit) : Unit { + body intrinsic; } }"; let qir = compile_source_to_qir(source, *CAPABILITIES); @@ -848,57 +3631,83 @@ mod adaptive_profile { %Result = type opaque %Qubit = type opaque - @0 = internal constant [4 x i8] c"0_a\00" - @1 = internal constant [6 x i8] c"1_a0r\00" - @2 = internal constant [6 x i8] c"2_a1r\00" + @0 = internal constant [4 x i8] c"0_r\00" 
define i64 @ENTRYPOINT__main() #0 { block_0: call void @__quantum__rt__initialize(i8* null) - call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) - call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 1 to %Qubit*)) - call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) - call void @__quantum__rt__array_record_output(i64 2, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) - call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @1, i64 0, i64 0)) - call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @2, i64 0, i64 0)) + call void @__quantum__qis__custom_reset__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) ret i64 0 } declare void @__quantum__rt__initialize(i8*) - declare void @__quantum__qis__x__body(%Qubit*) - - declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 + declare void @__quantum__qis__custom_reset__body(%Qubit*) #1 - declare void @__quantum__rt__array_record_output(i64, i8*) + declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 declare void @__quantum__rt__result_record_output(%Result*, i8*) - attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="2" "required_num_results"="2" } + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="1" } attributes #1 
= { "irreversible" } ; module flags - !llvm.module.flags = !{!0, !1, !2, !3} + !llvm.module.flags = !{!0, !1, !2, !3, !4} !0 = !{i32 1, !"qir_major_version", i32 1} !1 = !{i32 7, !"qir_minor_version", i32 0} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} - "#]].assert_eq(&qir); + !4 = !{i32 5, !"int_computations", !{!"i64"}} + "#]] + .assert_eq(&qir); } -} -mod adaptive_ri_profile { + #[test] + fn terminal_result_return_with_qubit_cleanup_generates_correct_qir() { + let qir = compile_source_to_qir( + terminal_result_return_with_qubit_cleanup_source(), + *CAPABILITIES, + ); + assert_terminal_result_return_with_qubit_cleanup_qir(&qir); + } + + #[test] + fn terminal_result_return_with_qubit_cleanup_generates_correct_qir_from_ast() { + let qir = compile_source_to_qir_from_ast( + terminal_result_return_with_qubit_cleanup_source(), + *CAPABILITIES, + ); + assert_terminal_result_return_with_qubit_cleanup_qir(&qir); + } + + #[test] + fn terminal_result_return_with_qubit_cleanup_generates_rir() { + let rir = compile_source_to_rir( + terminal_result_return_with_qubit_cleanup_source(), + *CAPABILITIES, + ); + let [raw, ssa] = rir.as_slice() else { + panic!("expected raw and SSA RIR programs"); + }; + + assert_terminal_result_return_with_qubit_cleanup_rir(raw, "raw"); + assert_terminal_result_return_with_qubit_cleanup_rir(ssa, "ssa"); + } +} +mod adaptive_rif_profile { + use super::compile_source_to_qir; use expect_test::expect; use qsc_data_structures::target::TargetCapabilityFlags; - - use super::compile_source_to_qir; static CAPABILITIES: std::sync::LazyLock = std::sync::LazyLock::new(|| { - TargetCapabilityFlags::Adaptive | TargetCapabilityFlags::IntegerComputations + TargetCapabilityFlags::Adaptive + | TargetCapabilityFlags::IntegerComputations + | TargetCapabilityFlags::FloatingPointComputations }); #[test] @@ -949,17 +3758,83 @@ mod adaptive_ri_profile { ; module flags - !llvm.module.flags = !{!0, !1, !2, !3, 
!4} + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5} !0 = !{i32 1, !"qir_major_version", i32 1} !1 = !{i32 7, !"qir_minor_version", i32 0} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} !4 = !{i32 5, !"int_computations", !{!"i64"}} + !5 = !{i32 5, !"float_computations", !{!"double"}} "#]] .assert_eq(&qir); } + #[test] + fn tuple_comparison_generates_qir_after_pipeline() { + let qir = compile_source_to_qir( + indoc::indoc! {r#" + namespace Test { + @EntryPoint() + operation Main() : Bool { + use (q0, q1) = (Qubit(), Qubit()); + let lhs = (MResetZ(q0), MResetZ(q1)); + lhs == (Zero, Zero) + } + } + "#}, + *CAPABILITIES, + ); + + expect![[r#" + %Result = type opaque + %Qubit = type opaque + + @0 = internal constant [4 x i8] c"0_b\00" + + define i64 @ENTRYPOINT__main() #0 { + block_0: + call void @__quantum__rt__initialize(i8* null) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) + %var_0 = call i1 @__quantum__rt__read_result(%Result* inttoptr (i64 0 to %Result*)) + %var_1 = icmp eq i1 %var_0, false + br i1 %var_1, label %block_1, label %block_2 + block_1: + %var_3 = call i1 @__quantum__rt__read_result(%Result* inttoptr (i64 1 to %Result*)) + %var_4 = icmp eq i1 %var_3, false + br label %block_2 + block_2: + %var_6 = phi i1 [false, %block_0], [%var_4, %block_1] + call void @__quantum__rt__bool_record_output(i1 %var_6, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + ret i64 0 + } + + declare void @__quantum__rt__initialize(i8*) + + declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 + + declare i1 @__quantum__rt__read_result(%Result*) + + declare void @__quantum__rt__bool_record_output(i1, i8*) + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" 
"required_num_qubits"="2" "required_num_results"="2" } + attributes #1 = { "irreversible" } + + ; module flags + + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5} + + !0 = !{i32 1, !"qir_major_version", i32 1} + !1 = !{i32 7, !"qir_minor_version", i32 0} + !2 = !{i32 1, !"dynamic_qubit_management", i1 false} + !3 = !{i32 1, !"dynamic_result_management", i1 false} + !4 = !{i32 5, !"int_computations", !{!"i64"}} + !5 = !{i32 5, !"float_computations", !{!"double"}} + "#]] + .assert_eq(&qir); + } + #[test] fn qubit_reuse_allowed() { let source = "namespace Test { @@ -1002,13 +3877,14 @@ mod adaptive_ri_profile { ; module flags - !llvm.module.flags = !{!0, !1, !2, !3, !4} + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5} !0 = !{i32 1, !"qir_major_version", i32 1} !1 = !{i32 7, !"qir_minor_version", i32 0} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} !4 = !{i32 5, !"int_computations", !{!"i64"}} + !5 = !{i32 5, !"float_computations", !{!"double"}} "#]].assert_eq(&qir); } @@ -1062,13 +3938,14 @@ mod adaptive_ri_profile { ; module flags - !llvm.module.flags = !{!0, !1, !2, !3, !4} + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5} !0 = !{i32 1, !"qir_major_version", i32 1} !1 = !{i32 7, !"qir_minor_version", i32 0} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} !4 = !{i32 5, !"int_computations", !{!"i64"}} + !5 = !{i32 5, !"float_computations", !{!"double"}} "#]].assert_eq(&qir); } @@ -1121,13 +3998,14 @@ mod adaptive_ri_profile { ; module flags - !llvm.module.flags = !{!0, !1, !2, !3, !4} + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5} !0 = !{i32 1, !"qir_major_version", i32 1} !1 = !{i32 7, !"qir_minor_version", i32 0} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} !4 = !{i32 5, !"int_computations", !{!"i64"}} + !5 = !{i32 5, !"float_computations", !{!"double"}} "#]].assert_eq(&qir); } @@ 
-1189,13 +4067,14 @@ mod adaptive_ri_profile { ; module flags - !llvm.module.flags = !{!0, !1, !2, !3, !4} + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5} !0 = !{i32 1, !"qir_major_version", i32 1} !1 = !{i32 7, !"qir_minor_version", i32 0} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} !4 = !{i32 5, !"int_computations", !{!"i64"}} + !5 = !{i32 5, !"float_computations", !{!"double"}} "#]].assert_eq(&qir); } @@ -1256,13 +4135,14 @@ mod adaptive_ri_profile { ; module flags - !llvm.module.flags = !{!0, !1, !2, !3, !4} + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5} !0 = !{i32 1, !"qir_major_version", i32 1} !1 = !{i32 7, !"qir_minor_version", i32 0} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} !4 = !{i32 5, !"int_computations", !{!"i64"}} + !5 = !{i32 5, !"float_computations", !{!"double"}} "#]].assert_eq(&qir); } @@ -1281,7 +4161,192 @@ mod adaptive_ri_profile { %Result = type opaque %Qubit = type opaque - @0 = internal constant [4 x i8] c"0_i\00" + @0 = internal constant [4 x i8] c"0_i\00" + + define i64 @ENTRYPOINT__main() #0 { + block_0: + call void @__quantum__rt__initialize(i8* null) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + %var_0 = call i1 @__quantum__rt__read_result(%Result* inttoptr (i64 0 to %Result*)) + %var_1 = icmp eq i1 %var_0, false + br i1 %var_1, label %block_1, label %block_2 + block_1: + br label %block_3 + block_2: + br label %block_3 + block_3: + %var_4 = phi i64 [0, %block_1], [1, %block_2] + call void @__quantum__rt__int_record_output(i64 %var_4, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + ret i64 0 + } + + declare void @__quantum__rt__initialize(i8*) + + declare void @__quantum__qis__h__body(%Qubit*) + + declare void 
@__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 + + declare i1 @__quantum__rt__read_result(%Result*) + + declare void @__quantum__rt__int_record_output(i64, i8*) + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="1" } + attributes #1 = { "irreversible" } + + ; module flags + + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5} + + !0 = !{i32 1, !"qir_major_version", i32 1} + !1 = !{i32 7, !"qir_minor_version", i32 0} + !2 = !{i32 1, !"dynamic_qubit_management", i1 false} + !3 = !{i32 1, !"dynamic_result_management", i1 false} + !4 = !{i32 5, !"int_computations", !{!"i64"}} + !5 = !{i32 5, !"float_computations", !{!"double"}} + "#]].assert_eq(&qir); + } + + #[test] + fn dynamic_double_with_branch_and_phi_supported() { + let source = "namespace Test { + @EntryPoint() + operation Main() : Double { + use q = Qubit(); + H(q); + MResetZ(q) == Zero ? 0.0 | 1.0 + } + }"; + let qir = compile_source_to_qir(source, *CAPABILITIES); + expect![[r#" + %Result = type opaque + %Qubit = type opaque + + @0 = internal constant [4 x i8] c"0_d\00" + + define i64 @ENTRYPOINT__main() #0 { + block_0: + call void @__quantum__rt__initialize(i8* null) + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + %var_0 = call i1 @__quantum__rt__read_result(%Result* inttoptr (i64 0 to %Result*)) + %var_1 = icmp eq i1 %var_0, false + br i1 %var_1, label %block_1, label %block_2 + block_1: + br label %block_3 + block_2: + br label %block_3 + block_3: + %var_4 = phi double [0.0, %block_1], [1.0, %block_2] + call void @__quantum__rt__double_record_output(double %var_4, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + ret i64 0 + } + + declare void @__quantum__rt__initialize(i8*) + + declare void @__quantum__qis__h__body(%Qubit*) + + declare void 
@__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 + + declare i1 @__quantum__rt__read_result(%Result*) + + declare void @__quantum__rt__double_record_output(double, i8*) + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="1" } + attributes #1 = { "irreversible" } + + ; module flags + + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5} + + !0 = !{i32 1, !"qir_major_version", i32 1} + !1 = !{i32 7, !"qir_minor_version", i32 0} + !2 = !{i32 1, !"dynamic_qubit_management", i1 false} + !3 = !{i32 1, !"dynamic_result_management", i1 false} + !4 = !{i32 5, !"int_computations", !{!"i64"}} + !5 = !{i32 5, !"float_computations", !{!"double"}} + "#]].assert_eq(&qir); + } + + #[test] + fn custom_reset_generates_correct_qir() { + let source = "namespace Test { + operation Main() : Result { + use q = Qubit(); + __quantum__qis__custom_reset__body(q); + M(q) + } + + @Reset() + operation __quantum__qis__custom_reset__body(target: Qubit) : Unit { + body intrinsic; + } + }"; + let qir = compile_source_to_qir(source, *CAPABILITIES); + expect![[r#" + %Result = type opaque + %Qubit = type opaque + + @0 = internal constant [4 x i8] c"0_r\00" + + define i64 @ENTRYPOINT__main() #0 { + block_0: + call void @__quantum__rt__initialize(i8* null) + call void @__quantum__qis__custom_reset__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + ret i64 0 + } + + declare void @__quantum__rt__initialize(i8*) + + declare void @__quantum__qis__custom_reset__body(%Qubit*) #1 + + declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 + + declare void @__quantum__rt__result_record_output(%Result*, i8*) + + attributes #0 = { "entry_point" 
"output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="1" } + attributes #1 = { "irreversible" } + + ; module flags + + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5} + + !0 = !{i32 1, !"qir_major_version", i32 1} + !1 = !{i32 7, !"qir_minor_version", i32 0} + !2 = !{i32 1, !"dynamic_qubit_management", i1 false} + !3 = !{i32 1, !"dynamic_result_management", i1 false} + !4 = !{i32 5, !"int_computations", !{!"i64"}} + !5 = !{i32 5, !"float_computations", !{!"double"}} + "#]] + .assert_eq(&qir); + } + + #[test] + fn dynamic_double_intrinsic() { + let source = "namespace Test { + operation OpA(theta: Double, q : Qubit) : Unit { body intrinsic; } + @EntryPoint() + operation Main() : Double { + use q = Qubit(); + H(q); + let theta = MResetZ(q) == Zero ? 0.0 | 1.0; + OpA(1.0 + theta, q); + Rx(2.0 * theta, q); + Ry(theta / 3.0, q); + Rz(theta - 4.0, q); + OpA(theta, q); + Rx(theta, q); + theta + } + }"; + let qir = compile_source_to_qir(source, *CAPABILITIES); + expect![[r#" + %Result = type opaque + %Qubit = type opaque + + @0 = internal constant [4 x i8] c"0_d\00" define i64 @ENTRYPOINT__main() #0 { block_0: @@ -1296,8 +4361,18 @@ mod adaptive_ri_profile { block_2: br label %block_3 block_3: - %var_4 = phi i64 [0, %block_1], [1, %block_2] - call void @__quantum__rt__int_record_output(i64 %var_4, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + %var_9 = phi double [0.0, %block_1], [1.0, %block_2] + %var_4 = fadd double 1.0, %var_9 + call void @OpA(double %var_4, %Qubit* inttoptr (i64 0 to %Qubit*)) + %var_5 = fmul double 2.0, %var_9 + call void @__quantum__qis__rx__body(double %var_5, %Qubit* inttoptr (i64 0 to %Qubit*)) + %var_6 = fdiv double %var_9, 3.0 + call void @__quantum__qis__ry__body(double %var_6, %Qubit* inttoptr (i64 0 to %Qubit*)) + %var_7 = fsub double %var_9, 4.0 + call void @__quantum__qis__rz__body(double %var_7, %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @OpA(double %var_9, 
%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__rx__body(double %var_9, %Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__rt__double_record_output(double %var_9, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) ret i64 0 } @@ -1309,721 +4384,1095 @@ mod adaptive_ri_profile { declare i1 @__quantum__rt__read_result(%Result*) - declare void @__quantum__rt__int_record_output(i64, i8*) + declare void @OpA(double, %Qubit*) + + declare void @__quantum__qis__rx__body(double, %Qubit*) + + declare void @__quantum__qis__ry__body(double, %Qubit*) + + declare void @__quantum__qis__rz__body(double, %Qubit*) + + declare void @__quantum__rt__double_record_output(double, i8*) attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="1" } attributes #1 = { "irreversible" } ; module flags - !llvm.module.flags = !{!0, !1, !2, !3, !4} + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5} !0 = !{i32 1, !"qir_major_version", i32 1} !1 = !{i32 7, !"qir_minor_version", i32 0} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} !4 = !{i32 5, !"int_computations", !{!"i64"}} + !5 = !{i32 5, !"float_computations", !{!"double"}} "#]].assert_eq(&qir); } +} + +mod adaptive_rifla_profile { + use super::compile_source_to_qir; + use super::compile_source_to_qir_result; + use expect_test::expect; + use qsc_data_structures::target::TargetCapabilityFlags; + + static CAPABILITIES: std::sync::LazyLock = + std::sync::LazyLock::new(|| { + TargetCapabilityFlags::Adaptive + | TargetCapabilityFlags::IntegerComputations + | TargetCapabilityFlags::FloatingPointComputations + | TargetCapabilityFlags::BackwardsBranching + | TargetCapabilityFlags::StaticSizedArrays + }); #[test] - fn custom_reset_generates_correct_qir() { + fn nested_for_over_qubit_slice_succeeds() { let source = "namespace Test { - operation Main() : Result { - use q = 
Qubit(); - __quantum__qis__custom_reset__body(q); - M(q) - } - - @Reset() - operation __quantum__qis__custom_reset__body(target: Qubit) : Unit { - body intrinsic; + import Std.Intrinsic.*; + @EntryPoint() + operation Main() : Unit { + use qs = Qubit[3]; + X(qs[0]); + for _ in 1..2 { + for q in qs[1...] { + CNOT(qs[0], q); + } + } } }"; let qir = compile_source_to_qir(source, *CAPABILITIES); expect![[r#" - %Result = type opaque - %Qubit = type opaque - - @0 = internal constant [4 x i8] c"0_r\00" + @0 = internal constant [4 x i8] c"0_t\00" + @array0 = internal constant [2 x ptr] [ptr inttoptr (i64 1 to ptr), ptr inttoptr (i64 2 to ptr)] define i64 @ENTRYPOINT__main() #0 { block_0: - call void @__quantum__rt__initialize(i8* null) - call void @__quantum__qis__custom_reset__body(%Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) - call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + %var_1 = alloca i64 + %var_3 = alloca i1 + %var_4 = alloca i64 + call void @__quantum__rt__initialize(ptr null) + call void @__quantum__qis__x__body(ptr inttoptr (i64 0 to ptr)) + store i64 1, ptr %var_1 + br label %block_1 + block_1: + %var_11 = load i64, ptr %var_1 + %var_2 = icmp sle i64 %var_11, 2 + store i1 true, ptr %var_3 + br i1 %var_2, label %block_2, label %block_3 + block_2: + %var_14 = load i1, ptr %var_3 + br i1 %var_14, label %block_4, label %block_5 + block_3: + store i1 false, ptr %var_3 + br label %block_2 + block_4: + store i64 0, ptr %var_4 + br label %block_6 + block_5: + call void @__quantum__rt__tuple_record_output(i64 0, ptr @0) ret i64 0 + block_6: + %var_16 = load i64, ptr %var_4 + %var_5 = icmp slt i64 %var_16, 2 + br i1 %var_5, label %block_7, label %block_8 + block_7: + %var_19 = load i64, ptr %var_4 + %var_6 = getelementptr ptr, ptr @array0, i64 %var_19 + %var_20 = 
load ptr, ptr %var_6 + call void @__quantum__qis__cx__body(ptr inttoptr (i64 0 to ptr), ptr %var_20) + %var_8 = add i64 %var_19, 1 + store i64 %var_8, ptr %var_4 + br label %block_6 + block_8: + %var_17 = load i64, ptr %var_1 + %var_9 = add i64 %var_17, 1 + store i64 %var_9, ptr %var_1 + br label %block_1 } - declare void @__quantum__rt__initialize(i8*) + declare void @__quantum__rt__initialize(ptr) - declare void @__quantum__qis__custom_reset__body(%Qubit*) #1 + declare void @__quantum__qis__x__body(ptr) - declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 + declare void @__quantum__qis__cx__body(ptr, ptr) - declare void @__quantum__rt__result_record_output(%Result*, i8*) + declare void @__quantum__rt__tuple_record_output(i64, ptr) - attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="1" } + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="3" "required_num_results"="0" } attributes #1 = { "irreversible" } ; module flags - !llvm.module.flags = !{!0, !1, !2, !3, !4} + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5, !6, !7} - !0 = !{i32 1, !"qir_major_version", i32 1} - !1 = !{i32 7, !"qir_minor_version", i32 0} + !0 = !{i32 1, !"qir_major_version", i32 2} + !1 = !{i32 7, !"qir_minor_version", i32 1} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} !4 = !{i32 5, !"int_computations", !{!"i64"}} + !5 = !{i32 5, !"float_computations", !{!"double"}} + !6 = !{i32 7, !"backwards_branching", i2 3} + !7 = !{i32 1, !"arrays", i1 true} "#]] - .assert_eq(&qir); + .assert_eq(&qir); } -} - -mod adaptive_rif_profile { - use super::compile_source_to_qir; - use expect_test::expect; - use qsc_data_structures::target::TargetCapabilityFlags; - static CAPABILITIES: std::sync::LazyLock = - std::sync::LazyLock::new(|| { - TargetCapabilityFlags::Adaptive - | 
TargetCapabilityFlags::IntegerComputations - | TargetCapabilityFlags::FloatingPointComputations - }); #[test] - fn simple() { + fn constant_folding_pattern_succeeds() { let source = "namespace Test { - import Std.Math.*; - open QIR.Intrinsic; + import Std.Intrinsic.*; @EntryPoint() - operation Main() : Result { - use q = Qubit(); - let pi_over_two = 4.0 / 2.0; - __quantum__qis__rz__body(pi_over_two, q); - mutable some_angle = ArcSin(0.0); - __quantum__qis__rz__body(some_angle, q); - set some_angle = ArcCos(-1.0) / PI(); - __quantum__qis__rz__body(some_angle, q); - __quantum__qis__mresetz__body(q) + operation Main() : Result[] { + use qs = Qubit[3]; + let iterations = 2; + X(qs[0]); + for _ in 1..iterations { + for q in qs[1...] { + CNOT(qs[0], q); + } + } + MResetEachZ(qs) } }"; let qir = compile_source_to_qir(source, *CAPABILITIES); expect![[r#" - %Result = type opaque - %Qubit = type opaque - - @0 = internal constant [4 x i8] c"0_r\00" + @0 = internal constant [4 x i8] c"0_a\00" + @1 = internal constant [6 x i8] c"1_a0r\00" + @2 = internal constant [6 x i8] c"2_a1r\00" + @3 = internal constant [6 x i8] c"3_a2r\00" + @array0 = internal constant [2 x ptr] [ptr inttoptr (i64 1 to ptr), ptr inttoptr (i64 2 to ptr)] define i64 @ENTRYPOINT__main() #0 { block_0: - call void @__quantum__rt__initialize(i8* null) - call void @__quantum__qis__rz__body(double 2.0, %Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__rz__body(double 0.0, %Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__rz__body(double 1.0, %Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) - call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + %var_1 = alloca i64 + %var_3 = alloca i1 + %var_4 = alloca i64 + call void @__quantum__rt__initialize(ptr null) + call void 
@__quantum__qis__x__body(ptr inttoptr (i64 0 to ptr)) + store i64 1, ptr %var_1 + br label %block_1 + block_1: + %var_11 = load i64, ptr %var_1 + %var_2 = icmp sle i64 %var_11, 2 + store i1 true, ptr %var_3 + br i1 %var_2, label %block_2, label %block_3 + block_2: + %var_14 = load i1, ptr %var_3 + br i1 %var_14, label %block_4, label %block_5 + block_3: + store i1 false, ptr %var_3 + br label %block_2 + block_4: + store i64 0, ptr %var_4 + br label %block_6 + block_5: + call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 0 to ptr)) + call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 1 to ptr), ptr inttoptr (i64 1 to ptr)) + call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 2 to ptr), ptr inttoptr (i64 2 to ptr)) + call void @__quantum__rt__array_record_output(i64 3, ptr @0) + call void @__quantum__rt__result_record_output(ptr inttoptr (i64 0 to ptr), ptr @1) + call void @__quantum__rt__result_record_output(ptr inttoptr (i64 1 to ptr), ptr @2) + call void @__quantum__rt__result_record_output(ptr inttoptr (i64 2 to ptr), ptr @3) ret i64 0 + block_6: + %var_16 = load i64, ptr %var_4 + %var_5 = icmp slt i64 %var_16, 2 + br i1 %var_5, label %block_7, label %block_8 + block_7: + %var_19 = load i64, ptr %var_4 + %var_6 = getelementptr ptr, ptr @array0, i64 %var_19 + %var_20 = load ptr, ptr %var_6 + call void @__quantum__qis__cx__body(ptr inttoptr (i64 0 to ptr), ptr %var_20) + %var_8 = add i64 %var_19, 1 + store i64 %var_8, ptr %var_4 + br label %block_6 + block_8: + %var_17 = load i64, ptr %var_1 + %var_9 = add i64 %var_17, 1 + store i64 %var_9, ptr %var_1 + br label %block_1 } - declare void @__quantum__rt__initialize(i8*) + declare void @__quantum__rt__initialize(ptr) - declare void @__quantum__qis__rz__body(double, %Qubit*) + declare void @__quantum__qis__x__body(ptr) - declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 + declare void @__quantum__qis__cx__body(ptr, ptr) - declare void 
@__quantum__rt__result_record_output(%Result*, i8*) + declare void @__quantum__qis__mresetz__body(ptr, ptr) #1 - attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="1" } + declare void @__quantum__rt__array_record_output(i64, ptr) + + declare void @__quantum__rt__result_record_output(ptr, ptr) + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="3" "required_num_results"="3" } attributes #1 = { "irreversible" } ; module flags - !llvm.module.flags = !{!0, !1, !2, !3, !4, !5} + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5, !6, !7} - !0 = !{i32 1, !"qir_major_version", i32 1} - !1 = !{i32 7, !"qir_minor_version", i32 0} + !0 = !{i32 1, !"qir_major_version", i32 2} + !1 = !{i32 7, !"qir_minor_version", i32 1} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} !4 = !{i32 5, !"int_computations", !{!"i64"}} !5 = !{i32 5, !"float_computations", !{!"double"}} + !6 = !{i32 7, !"backwards_branching", i2 3} + !7 = !{i32 1, !"arrays", i1 true} "#]] - .assert_eq(&qir); + .assert_eq(&qir); } #[test] - fn qubit_reuse_allowed() { + fn three_qubit_repetition_code_pattern_succeeds() { let source = "namespace Test { + import Std.Intrinsic.*; + operation ApplyRotationalIdentity(register : Qubit[]) : Unit { + let theta = 2.0 * 3.14159265; + for qubit in register { + Rx(theta, qubit); + } + } @EntryPoint() - operation Main() : (Result, Result) { - use q = Qubit(); - (MResetZ(q), MResetZ(q)) + operation Main() : Result[] { + use qs = Qubit[3]; + X(qs[0]); + let iterations = 2; + for _ in 1..iterations { + for q in qs[1...] 
{ + CNOT(qs[0], q); + } + ApplyRotationalIdentity(qs); + } + MResetEachZ(qs) } }"; let qir = compile_source_to_qir(source, *CAPABILITIES); expect![[r#" - %Result = type opaque - %Qubit = type opaque - - @0 = internal constant [4 x i8] c"0_t\00" - @1 = internal constant [6 x i8] c"1_t0r\00" - @2 = internal constant [6 x i8] c"2_t1r\00" + @0 = internal constant [4 x i8] c"0_a\00" + @1 = internal constant [6 x i8] c"1_a0r\00" + @2 = internal constant [6 x i8] c"2_a1r\00" + @3 = internal constant [6 x i8] c"3_a2r\00" + @array0 = internal constant [2 x ptr] [ptr inttoptr (i64 1 to ptr), ptr inttoptr (i64 2 to ptr)] + @array1 = internal constant [3 x ptr] [ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 1 to ptr), ptr inttoptr (i64 2 to ptr)] define i64 @ENTRYPOINT__main() #0 { block_0: - call void @__quantum__rt__initialize(i8* null) - call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) - call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) - call void @__quantum__rt__tuple_record_output(i64 2, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) - call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @1, i64 0, i64 0)) - call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @2, i64 0, i64 0)) + %var_1 = alloca i64 + %var_3 = alloca i1 + %var_4 = alloca i64 + %var_9 = alloca i64 + call void @__quantum__rt__initialize(ptr null) + call void @__quantum__qis__x__body(ptr inttoptr (i64 0 to ptr)) + store i64 1, ptr %var_1 + br label %block_1 + block_1: + %var_16 = load i64, ptr %var_1 + %var_2 = icmp sle i64 %var_16, 2 + store i1 true, ptr %var_3 + br i1 %var_2, label %block_2, label %block_3 + block_2: + %var_19 = load i1, ptr %var_3 + br i1 %var_19, label %block_4, label 
%block_5 + block_3: + store i1 false, ptr %var_3 + br label %block_2 + block_4: + store i64 0, ptr %var_4 + br label %block_6 + block_5: + call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 0 to ptr)) + call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 1 to ptr), ptr inttoptr (i64 1 to ptr)) + call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 2 to ptr), ptr inttoptr (i64 2 to ptr)) + call void @__quantum__rt__array_record_output(i64 3, ptr @0) + call void @__quantum__rt__result_record_output(ptr inttoptr (i64 0 to ptr), ptr @1) + call void @__quantum__rt__result_record_output(ptr inttoptr (i64 1 to ptr), ptr @2) + call void @__quantum__rt__result_record_output(ptr inttoptr (i64 2 to ptr), ptr @3) ret i64 0 + block_6: + %var_21 = load i64, ptr %var_4 + %var_5 = icmp slt i64 %var_21, 2 + br i1 %var_5, label %block_7, label %block_8 + block_7: + %var_29 = load i64, ptr %var_4 + %var_6 = getelementptr ptr, ptr @array0, i64 %var_29 + %var_30 = load ptr, ptr %var_6 + call void @__quantum__qis__cx__body(ptr inttoptr (i64 0 to ptr), ptr %var_30) + %var_8 = add i64 %var_29, 1 + store i64 %var_8, ptr %var_4 + br label %block_6 + block_8: + store i64 0, ptr %var_9 + br label %block_9 + block_9: + %var_23 = load i64, ptr %var_9 + %var_10 = icmp slt i64 %var_23, 3 + br i1 %var_10, label %block_10, label %block_11 + block_10: + %var_26 = load i64, ptr %var_9 + %var_11 = getelementptr ptr, ptr @array1, i64 %var_26 + %var_27 = load ptr, ptr %var_11 + call void @__quantum__qis__rx__body(double 6.2831853, ptr %var_27) + %var_13 = add i64 %var_26, 1 + store i64 %var_13, ptr %var_9 + br label %block_9 + block_11: + %var_24 = load i64, ptr %var_1 + %var_14 = add i64 %var_24, 1 + store i64 %var_14, ptr %var_1 + br label %block_1 } - declare void @__quantum__rt__initialize(i8*) + declare void @__quantum__rt__initialize(ptr) - declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 + declare void 
@__quantum__qis__x__body(ptr) - declare void @__quantum__rt__tuple_record_output(i64, i8*) + declare void @__quantum__qis__cx__body(ptr, ptr) - declare void @__quantum__rt__result_record_output(%Result*, i8*) + declare void @__quantum__qis__rx__body(double, ptr) - attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="2" } + declare void @__quantum__qis__mresetz__body(ptr, ptr) #1 + + declare void @__quantum__rt__array_record_output(i64, ptr) + + declare void @__quantum__rt__result_record_output(ptr, ptr) + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="3" "required_num_results"="3" } attributes #1 = { "irreversible" } ; module flags - !llvm.module.flags = !{!0, !1, !2, !3, !4, !5} + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5, !6, !7} - !0 = !{i32 1, !"qir_major_version", i32 1} - !1 = !{i32 7, !"qir_minor_version", i32 0} + !0 = !{i32 1, !"qir_major_version", i32 2} + !1 = !{i32 7, !"qir_minor_version", i32 1} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} !4 = !{i32 5, !"int_computations", !{!"i64"}} !5 = !{i32 5, !"float_computations", !{!"double"}} - "#]].assert_eq(&qir); + !6 = !{i32 7, !"backwards_branching", i2 3} + !7 = !{i32 1, !"arrays", i1 true} + "#]] + .assert_eq(&qir); } #[test] - fn qubit_measurements_not_deferred() { + fn for_over_qubit_slice_inside_dynamic_while_succeeds() { let source = "namespace Test { + import Std.Intrinsic.*; @EntryPoint() - operation Main() : Result[] { - use (q0, q1) = (Qubit(), Qubit()); - X(q0); - let r0 = MResetZ(q0); - X(q1); - let r1 = MResetZ(q1); - [r0, r1] + operation Main() : Unit { + use qs = Qubit[3]; + mutable done = false; + while not done { + for q in qs[1...] 
{ + CNOT(qs[0], q); + } + set done = MResetZ(qs[0]) == One; + } } }"; let qir = compile_source_to_qir(source, *CAPABILITIES); expect![[r#" - %Result = type opaque - %Qubit = type opaque - - @0 = internal constant [4 x i8] c"0_a\00" - @1 = internal constant [6 x i8] c"1_a0r\00" - @2 = internal constant [6 x i8] c"2_a1r\00" + @0 = internal constant [4 x i8] c"0_t\00" + @array0 = internal constant [2 x ptr] [ptr inttoptr (i64 1 to ptr), ptr inttoptr (i64 2 to ptr)] define i64 @ENTRYPOINT__main() #0 { block_0: - call void @__quantum__rt__initialize(i8* null) - call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) - call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 1 to %Qubit*)) - call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) - call void @__quantum__rt__array_record_output(i64 2, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) - call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @1, i64 0, i64 0)) - call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @2, i64 0, i64 0)) + %var_1 = alloca i1 + %var_3 = alloca i64 + call void @__quantum__rt__initialize(ptr null) + store i1 false, ptr %var_1 + br label %block_1 + block_1: + %var_10 = load i1, ptr %var_1 + %var_2 = xor i1 %var_10, true + br i1 %var_2, label %block_2, label %block_3 + block_2: + store i64 0, ptr %var_3 + br label %block_4 + block_3: + call void @__quantum__rt__tuple_record_output(i64 0, ptr @0) ret i64 0 + block_4: + %var_12 = load i64, ptr %var_3 + %var_4 = icmp slt i64 %var_12, 2 + br i1 %var_4, label %block_5, label %block_6 + block_5: + %var_14 = load i64, ptr %var_3 + %var_5 = getelementptr ptr, ptr @array0, i64 
%var_14 + %var_15 = load ptr, ptr %var_5 + call void @__quantum__qis__cx__body(ptr inttoptr (i64 0 to ptr), ptr %var_15) + %var_7 = add i64 %var_14, 1 + store i64 %var_7, ptr %var_3 + br label %block_4 + block_6: + call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 0 to ptr)) + %var_8 = call i1 @__quantum__rt__read_result(ptr inttoptr (i64 0 to ptr)) + store i1 %var_8, ptr %var_1 + br label %block_1 } - declare void @__quantum__rt__initialize(i8*) + declare void @__quantum__rt__initialize(ptr) - declare void @__quantum__qis__x__body(%Qubit*) + declare void @__quantum__qis__cx__body(ptr, ptr) - declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 + declare void @__quantum__qis__mresetz__body(ptr, ptr) #1 - declare void @__quantum__rt__array_record_output(i64, i8*) + declare i1 @__quantum__rt__read_result(ptr) - declare void @__quantum__rt__result_record_output(%Result*, i8*) + declare void @__quantum__rt__tuple_record_output(i64, ptr) - attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="2" "required_num_results"="2" } + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="3" "required_num_results"="1" } attributes #1 = { "irreversible" } ; module flags - !llvm.module.flags = !{!0, !1, !2, !3, !4, !5} + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5, !6, !7} - !0 = !{i32 1, !"qir_major_version", i32 1} - !1 = !{i32 7, !"qir_minor_version", i32 0} + !0 = !{i32 1, !"qir_major_version", i32 2} + !1 = !{i32 7, !"qir_minor_version", i32 1} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} !4 = !{i32 5, !"int_computations", !{!"i64"}} !5 = !{i32 5, !"float_computations", !{!"double"}} - "#]].assert_eq(&qir); + !6 = !{i32 7, !"backwards_branching", i2 3} + !7 = !{i32 1, !"arrays", i1 true} + "#]] + .assert_eq(&qir); } #[test] - fn 
qubit_id_swap_results_in_different_id_usage() { + fn result_array_dynamic_index_succeeds() { let source = "namespace Test { + import Std.Intrinsic.*; @EntryPoint() - operation Main() : (Result, Result) { - use (q0, q1) = (Qubit(), Qubit()); - X(q0); - Relabel([q0, q1], [q1, q0]); - X(q1); - (MResetZ(q0), MResetZ(q1)) + operation Main() : Int { + use qs = Qubit[4]; + let results = MResetEachZ(qs); + mutable count = 0; + for i in 0..3 { + if results[i] == One { + set count += 1; + } + } + count } }"; let qir = compile_source_to_qir(source, *CAPABILITIES); expect![[r#" - %Result = type opaque - %Qubit = type opaque - - @0 = internal constant [4 x i8] c"0_t\00" - @1 = internal constant [6 x i8] c"1_t0r\00" - @2 = internal constant [6 x i8] c"2_t1r\00" + @0 = internal constant [4 x i8] c"0_i\00" define i64 @ENTRYPOINT__main() #0 { block_0: - call void @__quantum__rt__initialize(i8* null) - call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) - call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) - call void @__quantum__rt__tuple_record_output(i64 2, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) - call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @1, i64 0, i64 0)) - call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @2, i64 0, i64 0)) + %var_2 = alloca i64 + call void @__quantum__rt__initialize(ptr null) + call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 0 to ptr)) + call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 1 to ptr), ptr inttoptr (i64 1 to ptr)) + call void 
@__quantum__qis__mresetz__body(ptr inttoptr (i64 2 to ptr), ptr inttoptr (i64 2 to ptr)) + call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 3 to ptr), ptr inttoptr (i64 3 to ptr)) + store i64 0, ptr %var_2 + %var_4 = call i1 @__quantum__rt__read_result(ptr inttoptr (i64 0 to ptr)) + br i1 %var_4, label %block_1, label %block_2 + block_1: + %var_24 = load i64, ptr %var_2 + %var_6 = add i64 %var_24, 1 + store i64 %var_6, ptr %var_2 + br label %block_2 + block_2: + %var_7 = call i1 @__quantum__rt__read_result(ptr inttoptr (i64 1 to ptr)) + br i1 %var_7, label %block_3, label %block_4 + block_3: + %var_22 = load i64, ptr %var_2 + %var_9 = add i64 %var_22, 1 + store i64 %var_9, ptr %var_2 + br label %block_4 + block_4: + %var_10 = call i1 @__quantum__rt__read_result(ptr inttoptr (i64 2 to ptr)) + br i1 %var_10, label %block_5, label %block_6 + block_5: + %var_20 = load i64, ptr %var_2 + %var_12 = add i64 %var_20, 1 + store i64 %var_12, ptr %var_2 + br label %block_6 + block_6: + %var_13 = call i1 @__quantum__rt__read_result(ptr inttoptr (i64 3 to ptr)) + br i1 %var_13, label %block_7, label %block_8 + block_7: + %var_18 = load i64, ptr %var_2 + %var_15 = add i64 %var_18, 1 + store i64 %var_15, ptr %var_2 + br label %block_8 + block_8: + %var_17 = load i64, ptr %var_2 + call void @__quantum__rt__int_record_output(i64 %var_17, ptr @0) ret i64 0 } - declare void @__quantum__rt__initialize(i8*) - - declare void @__quantum__qis__x__body(%Qubit*) + declare void @__quantum__rt__initialize(ptr) - declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 + declare void @__quantum__qis__mresetz__body(ptr, ptr) #1 - declare void @__quantum__rt__tuple_record_output(i64, i8*) + declare i1 @__quantum__rt__read_result(ptr) - declare void @__quantum__rt__result_record_output(%Result*, i8*) + declare void @__quantum__rt__int_record_output(i64, ptr) - attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="2" 
"required_num_results"="2" } + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="4" "required_num_results"="4" } attributes #1 = { "irreversible" } ; module flags - !llvm.module.flags = !{!0, !1, !2, !3, !4, !5} + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5, !6, !7} - !0 = !{i32 1, !"qir_major_version", i32 1} - !1 = !{i32 7, !"qir_minor_version", i32 0} + !0 = !{i32 1, !"qir_major_version", i32 2} + !1 = !{i32 7, !"qir_minor_version", i32 1} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} !4 = !{i32 5, !"int_computations", !{!"i64"}} !5 = !{i32 5, !"float_computations", !{!"double"}} - "#]].assert_eq(&qir); + !6 = !{i32 7, !"backwards_branching", i2 3} + !7 = !{i32 1, !"arrays", i1 true} + "#]] + .assert_eq(&qir); } #[test] - fn qubit_id_swap_across_reset_uses_updated_ids() { + fn result_array_while_loop_dynamic_index_succeeds() { let source = "namespace Test { - @EntryPoint() - operation Main() : (Result, Result) { - { - use (q0, q1) = (Qubit(), Qubit()); - X(q0); - Relabel([q0, q1], [q1, q0]); - X(q1); - Reset(q0); - Reset(q1); + import Std.Intrinsic.*; + @EntryPoint() + operation Main() : Int { + use qs = Qubit[4]; + H(qs[0]); + H(qs[1]); + H(qs[2]); + H(qs[3]); + let r0 = MResetZ(qs[0]); + let r1 = MResetZ(qs[1]); + let r2 = MResetZ(qs[2]); + let r3 = MResetZ(qs[3]); + let results = [r0, r1, r2, r3]; + mutable count = 0; + mutable i = 0; + while i < 4 { + if results[i] == One { set count += 1; } + set i += 1; } - use (q0, q1) = (Qubit(), Qubit()); - (MResetZ(q0), MResetZ(q1)) + count } }"; let qir = compile_source_to_qir(source, *CAPABILITIES); expect![[r#" - %Result = type opaque - %Qubit = type opaque - - @0 = internal constant [4 x i8] c"0_t\00" - @1 = internal constant [6 x i8] c"1_t0r\00" - @2 = internal constant [6 x i8] c"2_t1r\00" + @0 = internal constant [4 x i8] c"0_i\00" define i64 @ENTRYPOINT__main() #0 { block_0: - call void 
@__quantum__rt__initialize(i8* null) - call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__reset__body(%Qubit* inttoptr (i64 1 to %Qubit*)) - call void @__quantum__qis__reset__body(%Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) - call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) - call void @__quantum__rt__tuple_record_output(i64 2, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) - call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @1, i64 0, i64 0)) - call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @2, i64 0, i64 0)) + %var_1 = alloca i64 + call void @__quantum__rt__initialize(ptr null) + call void @__quantum__qis__h__body(ptr inttoptr (i64 0 to ptr)) + call void @__quantum__qis__h__body(ptr inttoptr (i64 1 to ptr)) + call void @__quantum__qis__h__body(ptr inttoptr (i64 2 to ptr)) + call void @__quantum__qis__h__body(ptr inttoptr (i64 3 to ptr)) + call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 0 to ptr)) + call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 1 to ptr), ptr inttoptr (i64 1 to ptr)) + call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 2 to ptr), ptr inttoptr (i64 2 to ptr)) + call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 3 to ptr), ptr inttoptr (i64 3 to ptr)) + store i64 0, ptr %var_1 + %var_3 = call i1 @__quantum__rt__read_result(ptr inttoptr (i64 0 to ptr)) + br i1 %var_3, label %block_1, label %block_2 + block_1: + %var_23 = load i64, ptr %var_1 + %var_5 = add i64 %var_23, 1 + store i64 %var_5, ptr 
%var_1 + br label %block_2 + block_2: + %var_6 = call i1 @__quantum__rt__read_result(ptr inttoptr (i64 1 to ptr)) + br i1 %var_6, label %block_3, label %block_4 + block_3: + %var_21 = load i64, ptr %var_1 + %var_8 = add i64 %var_21, 1 + store i64 %var_8, ptr %var_1 + br label %block_4 + block_4: + %var_9 = call i1 @__quantum__rt__read_result(ptr inttoptr (i64 2 to ptr)) + br i1 %var_9, label %block_5, label %block_6 + block_5: + %var_19 = load i64, ptr %var_1 + %var_11 = add i64 %var_19, 1 + store i64 %var_11, ptr %var_1 + br label %block_6 + block_6: + %var_12 = call i1 @__quantum__rt__read_result(ptr inttoptr (i64 3 to ptr)) + br i1 %var_12, label %block_7, label %block_8 + block_7: + %var_17 = load i64, ptr %var_1 + %var_14 = add i64 %var_17, 1 + store i64 %var_14, ptr %var_1 + br label %block_8 + block_8: + %var_16 = load i64, ptr %var_1 + call void @__quantum__rt__int_record_output(i64 %var_16, ptr @0) ret i64 0 } - declare void @__quantum__rt__initialize(i8*) - - declare void @__quantum__qis__x__body(%Qubit*) + declare void @__quantum__rt__initialize(ptr) - declare void @__quantum__qis__reset__body(%Qubit*) #1 + declare void @__quantum__qis__h__body(ptr) - declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 + declare void @__quantum__qis__mresetz__body(ptr, ptr) #1 - declare void @__quantum__rt__tuple_record_output(i64, i8*) + declare i1 @__quantum__rt__read_result(ptr) - declare void @__quantum__rt__result_record_output(%Result*, i8*) + declare void @__quantum__rt__int_record_output(i64, ptr) - attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="2" "required_num_results"="2" } + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="4" "required_num_results"="4" } attributes #1 = { "irreversible" } ; module flags - !llvm.module.flags = !{!0, !1, !2, !3, !4, !5} + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5, !6, !7} - !0 = 
!{i32 1, !"qir_major_version", i32 1} - !1 = !{i32 7, !"qir_minor_version", i32 0} + !0 = !{i32 1, !"qir_major_version", i32 2} + !1 = !{i32 7, !"qir_minor_version", i32 1} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} !4 = !{i32 5, !"int_computations", !{!"i64"}} !5 = !{i32 5, !"float_computations", !{!"double"}} - "#]].assert_eq(&qir); + !6 = !{i32 7, !"backwards_branching", i2 3} + !7 = !{i32 1, !"arrays", i1 true} + "#]] + .assert_eq(&qir); } #[test] - fn qubit_id_swap_with_out_of_order_release_uses_correct_ids() { + #[ignore = "CapabilitiesCk(UseOfDynamicResult) — mutable Result re-measurement requires UseOfDynamicResult, not in RIFLA profile"] + fn mutable_result_variable_succeeds() { let source = "namespace Test { + import Std.Intrinsic.*; @EntryPoint() - operation Main() : (Result, Result) { - let q0 = QIR.Runtime.__quantum__rt__qubit_allocate(); - let q1 = QIR.Runtime.__quantum__rt__qubit_allocate(); - let q2 = QIR.Runtime.__quantum__rt__qubit_allocate(); - X(q0); - X(q1); - X(q2); - Relabel([q0, q1], [q1, q0]); - QIR.Runtime.__quantum__rt__qubit_release(q0); - let q3 = QIR.Runtime.__quantum__rt__qubit_allocate(); - X(q3); - (MResetZ(q3), MResetZ(q1)) + operation Main() : Result { + use q = Qubit(); + H(q); + mutable r = M(q); + if r == One { + X(q); + set r = M(q); + } + r + } + }"; + let qir = compile_source_to_qir_result(source, *CAPABILITIES) + .expect("mutable Result variable should compile"); + assert!(qir.contains("@ENTRYPOINT__main")); + } + + #[test] + fn for_loop_over_qubits_with_reset_all_succeeds() { + let source = "namespace Test { + import Std.Intrinsic.*; + @EntryPoint() + operation Main() : Result { + use qs = Qubit[4]; + for q in qs { + H(q); + } + let r = MResetZ(qs[0]); + ResetAll(qs[1..3]); + r } }"; let qir = compile_source_to_qir(source, *CAPABILITIES); expect![[r#" - %Result = type opaque - %Qubit = type opaque - - @0 = internal constant [4 x i8] c"0_t\00" - @1 = internal 
constant [6 x i8] c"1_t0r\00" - @2 = internal constant [6 x i8] c"2_t1r\00" + @0 = internal constant [4 x i8] c"0_r\00" + @array0 = internal constant [4 x ptr] [ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 1 to ptr), ptr inttoptr (i64 2 to ptr), ptr inttoptr (i64 3 to ptr)] + @array1 = internal constant [3 x ptr] [ptr inttoptr (i64 1 to ptr), ptr inttoptr (i64 2 to ptr), ptr inttoptr (i64 3 to ptr)] define i64 @ENTRYPOINT__main() #0 { block_0: - call void @__quantum__rt__initialize(i8* null) - call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 1 to %Qubit*)) - call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 2 to %Qubit*)) - call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 1 to %Qubit*)) - call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) - call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) - call void @__quantum__rt__tuple_record_output(i64 2, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) - call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @1, i64 0, i64 0)) - call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @2, i64 0, i64 0)) + %var_1 = alloca i64 + %var_6 = alloca i64 + call void @__quantum__rt__initialize(ptr null) + store i64 0, ptr %var_1 + br label %block_1 + block_1: + %var_12 = load i64, ptr %var_1 + %var_2 = icmp slt i64 %var_12, 4 + br i1 %var_2, label %block_2, label %block_3 + block_2: + %var_18 = load i64, ptr %var_1 + %var_3 = getelementptr ptr, ptr @array0, i64 %var_18 + %var_19 = load ptr, ptr %var_3 + call void @__quantum__qis__h__body(ptr %var_19) + %var_5 = add i64 %var_18, 1 + store i64 %var_5, ptr %var_1 + br label %block_1 + 
block_3: + call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 0 to ptr)) + store i64 0, ptr %var_6 + br label %block_4 + block_4: + %var_14 = load i64, ptr %var_6 + %var_7 = icmp slt i64 %var_14, 3 + br i1 %var_7, label %block_5, label %block_6 + block_5: + %var_15 = load i64, ptr %var_6 + %var_8 = getelementptr ptr, ptr @array1, i64 %var_15 + %var_16 = load ptr, ptr %var_8 + call void @__quantum__qis__reset__body(ptr %var_16) + %var_10 = add i64 %var_15, 1 + store i64 %var_10, ptr %var_6 + br label %block_4 + block_6: + call void @__quantum__rt__result_record_output(ptr inttoptr (i64 0 to ptr), ptr @0) ret i64 0 } - declare void @__quantum__rt__initialize(i8*) + declare void @__quantum__rt__initialize(ptr) - declare void @__quantum__qis__x__body(%Qubit*) + declare void @__quantum__qis__h__body(ptr) - declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 + declare void @__quantum__qis__mresetz__body(ptr, ptr) #1 - declare void @__quantum__rt__tuple_record_output(i64, i8*) + declare void @__quantum__qis__reset__body(ptr) #1 - declare void @__quantum__rt__result_record_output(%Result*, i8*) + declare void @__quantum__rt__result_record_output(ptr, ptr) - attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="3" "required_num_results"="2" } + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="4" "required_num_results"="1" } attributes #1 = { "irreversible" } ; module flags - !llvm.module.flags = !{!0, !1, !2, !3, !4, !5} + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5, !6, !7} - !0 = !{i32 1, !"qir_major_version", i32 1} - !1 = !{i32 7, !"qir_minor_version", i32 0} + !0 = !{i32 1, !"qir_major_version", i32 2} + !1 = !{i32 7, !"qir_minor_version", i32 1} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} !4 = !{i32 5, !"int_computations", 
!{!"i64"}} !5 = !{i32 5, !"float_computations", !{!"double"}} - "#]].assert_eq(&qir); + !6 = !{i32 7, !"backwards_branching", i2 3} + !7 = !{i32 1, !"arrays", i1 true} + "#]] + .assert_eq(&qir); } #[test] - fn dynamic_integer_with_branch_and_phi_supported() { + fn measure_each_z_static_qubits_succeeds() { let source = "namespace Test { + import Std.Intrinsic.*; @EntryPoint() - operation Main() : Int { - use q = Qubit(); - H(q); - MResetZ(q) == Zero ? 0 | 1 + operation Main() : Result[] { + use qs = Qubit[3]; + X(qs[0]); + H(qs[1]); + MResetEachZ(qs) } }"; let qir = compile_source_to_qir(source, *CAPABILITIES); expect![[r#" - %Result = type opaque - %Qubit = type opaque - - @0 = internal constant [4 x i8] c"0_i\00" + @0 = internal constant [4 x i8] c"0_a\00" + @1 = internal constant [6 x i8] c"1_a0r\00" + @2 = internal constant [6 x i8] c"2_a1r\00" + @3 = internal constant [6 x i8] c"3_a2r\00" define i64 @ENTRYPOINT__main() #0 { block_0: - call void @__quantum__rt__initialize(i8* null) - call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) - %var_0 = call i1 @__quantum__rt__read_result(%Result* inttoptr (i64 0 to %Result*)) - %var_1 = icmp eq i1 %var_0, false - br i1 %var_1, label %block_1, label %block_2 - block_1: - br label %block_3 - block_2: - br label %block_3 - block_3: - %var_4 = phi i64 [0, %block_1], [1, %block_2] - call void @__quantum__rt__int_record_output(i64 %var_4, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + call void @__quantum__rt__initialize(ptr null) + call void @__quantum__qis__x__body(ptr inttoptr (i64 0 to ptr)) + call void @__quantum__qis__h__body(ptr inttoptr (i64 1 to ptr)) + call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 0 to ptr)) + call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 1 to ptr), ptr inttoptr (i64 1 to ptr)) + call 
void @__quantum__qis__mresetz__body(ptr inttoptr (i64 2 to ptr), ptr inttoptr (i64 2 to ptr)) + call void @__quantum__rt__array_record_output(i64 3, ptr @0) + call void @__quantum__rt__result_record_output(ptr inttoptr (i64 0 to ptr), ptr @1) + call void @__quantum__rt__result_record_output(ptr inttoptr (i64 1 to ptr), ptr @2) + call void @__quantum__rt__result_record_output(ptr inttoptr (i64 2 to ptr), ptr @3) ret i64 0 } - declare void @__quantum__rt__initialize(i8*) + declare void @__quantum__rt__initialize(ptr) - declare void @__quantum__qis__h__body(%Qubit*) + declare void @__quantum__qis__x__body(ptr) - declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 + declare void @__quantum__qis__h__body(ptr) - declare i1 @__quantum__rt__read_result(%Result*) + declare void @__quantum__qis__mresetz__body(ptr, ptr) #1 - declare void @__quantum__rt__int_record_output(i64, i8*) + declare void @__quantum__rt__array_record_output(i64, ptr) - attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="1" } + declare void @__quantum__rt__result_record_output(ptr, ptr) + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="3" "required_num_results"="3" } attributes #1 = { "irreversible" } ; module flags - !llvm.module.flags = !{!0, !1, !2, !3, !4, !5} + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5, !6, !7} - !0 = !{i32 1, !"qir_major_version", i32 1} - !1 = !{i32 7, !"qir_minor_version", i32 0} + !0 = !{i32 1, !"qir_major_version", i32 2} + !1 = !{i32 7, !"qir_minor_version", i32 1} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} !4 = !{i32 5, !"int_computations", !{!"i64"}} !5 = !{i32 5, !"float_computations", !{!"double"}} - "#]].assert_eq(&qir); + !6 = !{i32 7, !"backwards_branching", i2 3} + !7 = !{i32 1, !"arrays", i1 true} + "#]] + .assert_eq(&qir); } 
#[test] - fn dynamic_double_with_branch_and_phi_supported() { + fn static_while_inside_emit_while_succeeds() { let source = "namespace Test { + import Std.Intrinsic.*; @EntryPoint() - operation Main() : Double { + operation Main() : Int { use q = Qubit(); - H(q); - MResetZ(q) == Zero ? 0.0 | 1.0 + mutable total = 0; + while MResetZ(q) == One { + mutable idx = 0; + while idx < 3 { + set total += 1; + set idx += 1; + } + } + total } }"; let qir = compile_source_to_qir(source, *CAPABILITIES); expect![[r#" - %Result = type opaque - %Qubit = type opaque - - @0 = internal constant [4 x i8] c"0_d\00" + @0 = internal constant [4 x i8] c"0_i\00" define i64 @ENTRYPOINT__main() #0 { block_0: - call void @__quantum__rt__initialize(i8* null) - call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) - %var_0 = call i1 @__quantum__rt__read_result(%Result* inttoptr (i64 0 to %Result*)) - %var_1 = icmp eq i1 %var_0, false - br i1 %var_1, label %block_1, label %block_2 + %var_0 = alloca i64 + %var_3 = alloca i64 + call void @__quantum__rt__initialize(ptr null) + store i64 0, ptr %var_0 + br label %block_1 block_1: - br label %block_3 + call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 0 to ptr)) + %var_1 = call i1 @__quantum__rt__read_result(ptr inttoptr (i64 0 to ptr)) + br i1 %var_1, label %block_2, label %block_3 block_2: - br label %block_3 + store i64 0, ptr %var_3 + br label %block_4 block_3: - %var_4 = phi double [0.0, %block_1], [1.0, %block_2] - call void @__quantum__rt__double_record_output(double %var_4, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + %var_8 = load i64, ptr %var_0 + call void @__quantum__rt__int_record_output(i64 %var_8, ptr @0) ret i64 0 + block_4: + %var_10 = load i64, ptr %var_3 + %var_4 = icmp slt i64 %var_10, 3 + br i1 %var_4, label %block_5, label %block_6 + 
block_5: + %var_11 = load i64, ptr %var_0 + %var_5 = add i64 %var_11, 1 + store i64 %var_5, ptr %var_0 + %var_13 = load i64, ptr %var_3 + %var_6 = add i64 %var_13, 1 + store i64 %var_6, ptr %var_3 + br label %block_4 + block_6: + br label %block_1 } - declare void @__quantum__rt__initialize(i8*) - - declare void @__quantum__qis__h__body(%Qubit*) + declare void @__quantum__rt__initialize(ptr) - declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 + declare void @__quantum__qis__mresetz__body(ptr, ptr) #1 - declare i1 @__quantum__rt__read_result(%Result*) + declare i1 @__quantum__rt__read_result(ptr) - declare void @__quantum__rt__double_record_output(double, i8*) + declare void @__quantum__rt__int_record_output(i64, ptr) attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="1" } attributes #1 = { "irreversible" } ; module flags - !llvm.module.flags = !{!0, !1, !2, !3, !4, !5} + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5, !6, !7} - !0 = !{i32 1, !"qir_major_version", i32 1} - !1 = !{i32 7, !"qir_minor_version", i32 0} + !0 = !{i32 1, !"qir_major_version", i32 2} + !1 = !{i32 7, !"qir_minor_version", i32 1} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} !4 = !{i32 5, !"int_computations", !{!"i64"}} !5 = !{i32 5, !"float_computations", !{!"double"}} - "#]].assert_eq(&qir); + !6 = !{i32 7, !"backwards_branching", i2 3} + !7 = !{i32 1, !"arrays", i1 true} + "#]] + .assert_eq(&qir); } #[test] - fn custom_reset_generates_correct_qir() { + fn nested_emit_while_loops_succeeds() { let source = "namespace Test { - operation Main() : Result { - use q = Qubit(); - __quantum__qis__custom_reset__body(q); - M(q) - } - - @Reset() - operation __quantum__qis__custom_reset__body(target: Qubit) : Unit { - body intrinsic; + import Std.Intrinsic.*; + @EntryPoint() + operation Main() : Int { + use qs = Qubit[2]; + mutable outer = 
0; + while outer < 3 { + H(qs[0]); + mutable inner = 0; + while inner < 2 { + H(qs[1]); + set inner += 1; + } + set outer += 1; + } + outer } }"; let qir = compile_source_to_qir(source, *CAPABILITIES); expect![[r#" - %Result = type opaque - %Qubit = type opaque - - @0 = internal constant [4 x i8] c"0_r\00" + @0 = internal constant [4 x i8] c"0_i\00" define i64 @ENTRYPOINT__main() #0 { block_0: - call void @__quantum__rt__initialize(i8* null) - call void @__quantum__qis__custom_reset__body(%Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) - call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + %var_1 = alloca i64 + %var_3 = alloca i64 + call void @__quantum__rt__initialize(ptr null) + store i64 0, ptr %var_1 + br label %block_1 + block_1: + %var_8 = load i64, ptr %var_1 + %var_2 = icmp slt i64 %var_8, 3 + br i1 %var_2, label %block_2, label %block_3 + block_2: + call void @__quantum__qis__h__body(ptr inttoptr (i64 0 to ptr)) + store i64 0, ptr %var_3 + br label %block_4 + block_3: + %var_9 = load i64, ptr %var_1 + call void @__quantum__rt__int_record_output(i64 %var_9, ptr @0) ret i64 0 + block_4: + %var_11 = load i64, ptr %var_3 + %var_4 = icmp slt i64 %var_11, 2 + br i1 %var_4, label %block_5, label %block_6 + block_5: + call void @__quantum__qis__h__body(ptr inttoptr (i64 1 to ptr)) + %var_14 = load i64, ptr %var_3 + %var_5 = add i64 %var_14, 1 + store i64 %var_5, ptr %var_3 + br label %block_4 + block_6: + %var_12 = load i64, ptr %var_1 + %var_6 = add i64 %var_12, 1 + store i64 %var_6, ptr %var_1 + br label %block_1 } - declare void @__quantum__rt__initialize(i8*) - - declare void @__quantum__qis__custom_reset__body(%Qubit*) #1 + declare void @__quantum__rt__initialize(ptr) - declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 + declare void 
@__quantum__qis__h__body(ptr) - declare void @__quantum__rt__result_record_output(%Result*, i8*) + declare void @__quantum__rt__int_record_output(i64, ptr) - attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="1" } + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="2" "required_num_results"="0" } attributes #1 = { "irreversible" } ; module flags - !llvm.module.flags = !{!0, !1, !2, !3, !4, !5} + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5, !6, !7} - !0 = !{i32 1, !"qir_major_version", i32 1} - !1 = !{i32 7, !"qir_minor_version", i32 0} + !0 = !{i32 1, !"qir_major_version", i32 2} + !1 = !{i32 7, !"qir_minor_version", i32 1} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} !4 = !{i32 5, !"int_computations", !{!"i64"}} !5 = !{i32 5, !"float_computations", !{!"double"}} + !6 = !{i32 7, !"backwards_branching", i2 3} + !7 = !{i32 1, !"arrays", i1 true} "#]] - .assert_eq(&qir); + .assert_eq(&qir); } #[test] - fn dynamic_double_intrinsic() { + fn for_loop_over_qubits_with_dynamic_exit_succeeds() { let source = "namespace Test { - operation OpA(theta: Double, q : Qubit) : Unit { body intrinsic; } + import Std.Intrinsic.*; @EntryPoint() - operation Main() : Double { - use q = Qubit(); - H(q); - let theta = MResetZ(q) == Zero ? 
0.0 | 1.0; - OpA(1.0 + theta, q); - Rx(2.0 * theta, q); - Ry(theta / 3.0, q); - Rz(theta - 4.0, q); - OpA(theta, q); - Rx(theta, q); - theta + operation Main() : Bool { + use qs = Qubit[3]; + mutable found = false; + for q in qs { + H(q); + if MResetZ(q) == One { + set found = true; + } + } + found } }"; let qir = compile_source_to_qir(source, *CAPABILITIES); expect![[r#" - %Result = type opaque - %Qubit = type opaque - - @0 = internal constant [4 x i8] c"0_d\00" + @0 = internal constant [4 x i8] c"0_b\00" + @array0 = internal constant [3 x ptr] [ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 1 to ptr), ptr inttoptr (i64 2 to ptr)] define i64 @ENTRYPOINT__main() #0 { block_0: - call void @__quantum__rt__initialize(i8* null) - call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) - %var_0 = call i1 @__quantum__rt__read_result(%Result* inttoptr (i64 0 to %Result*)) - %var_1 = icmp eq i1 %var_0, false - br i1 %var_1, label %block_1, label %block_2 + %var_1 = alloca i1 + %var_2 = alloca i64 + call void @__quantum__rt__initialize(ptr null) + store i1 false, ptr %var_1 + store i64 0, ptr %var_2 + br label %block_1 block_1: - br label %block_3 + %var_11 = load i64, ptr %var_2 + %var_3 = icmp slt i64 %var_11, 3 + br i1 %var_3, label %block_2, label %block_3 block_2: - br label %block_3 + %var_13 = load i64, ptr %var_2 + %var_4 = getelementptr ptr, ptr @array0, i64 %var_13 + %var_14 = load ptr, ptr %var_4 + call void @__quantum__qis__h__body(ptr %var_14) + call void @__quantum__qis__mresetz__body(ptr %var_14, ptr inttoptr (i64 0 to ptr)) + %var_6 = call i1 @__quantum__rt__read_result(ptr inttoptr (i64 0 to ptr)) + br i1 %var_6, label %block_4, label %block_5 block_3: - %var_9 = phi double [0.0, %block_1], [1.0, %block_2] - %var_4 = fadd double 1.0, %var_9 - call void @OpA(double %var_4, %Qubit* inttoptr (i64 0 to %Qubit*)) - %var_5 = fmul 
double 2.0, %var_9 - call void @__quantum__qis__rx__body(double %var_5, %Qubit* inttoptr (i64 0 to %Qubit*)) - %var_6 = fdiv double %var_9, 3.0 - call void @__quantum__qis__ry__body(double %var_6, %Qubit* inttoptr (i64 0 to %Qubit*)) - %var_7 = fsub double %var_9, 4.0 - call void @__quantum__qis__rz__body(double %var_7, %Qubit* inttoptr (i64 0 to %Qubit*)) - call void @OpA(double %var_9, %Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__rx__body(double %var_9, %Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__rt__double_record_output(double %var_9, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + %var_12 = load i1, ptr %var_1 + call void @__quantum__rt__bool_record_output(i1 %var_12, ptr @0) ret i64 0 + block_4: + store i1 true, ptr %var_1 + br label %block_5 + block_5: + %var_15 = load i64, ptr %var_2 + %var_8 = add i64 %var_15, 1 + store i64 %var_8, ptr %var_2 + br label %block_1 } - declare void @__quantum__rt__initialize(i8*) - - declare void @__quantum__qis__h__body(%Qubit*) - - declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 - - declare i1 @__quantum__rt__read_result(%Result*) - - declare void @OpA(double, %Qubit*) + declare void @__quantum__rt__initialize(ptr) - declare void @__quantum__qis__rx__body(double, %Qubit*) + declare void @__quantum__qis__h__body(ptr) - declare void @__quantum__qis__ry__body(double, %Qubit*) + declare void @__quantum__qis__mresetz__body(ptr, ptr) #1 - declare void @__quantum__qis__rz__body(double, %Qubit*) + declare i1 @__quantum__rt__read_result(ptr) - declare void @__quantum__rt__double_record_output(double, i8*) + declare void @__quantum__rt__bool_record_output(i1, ptr) - attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="1" } + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="3" "required_num_results"="1" } 
attributes #1 = { "irreversible" } ; module flags - !llvm.module.flags = !{!0, !1, !2, !3, !4, !5} + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5, !6, !7} - !0 = !{i32 1, !"qir_major_version", i32 1} - !1 = !{i32 7, !"qir_minor_version", i32 0} + !0 = !{i32 1, !"qir_major_version", i32 2} + !1 = !{i32 7, !"qir_minor_version", i32 1} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} !4 = !{i32 5, !"int_computations", !{!"i64"}} !5 = !{i32 5, !"float_computations", !{!"double"}} - "#]].assert_eq(&qir); + !6 = !{i32 7, !"backwards_branching", i2 3} + !7 = !{i32 1, !"arrays", i1 true} + "#]] + .assert_eq(&qir); } } diff --git a/source/compiler/qsc/src/compile.rs b/source/compiler/qsc/src/compile.rs index 2feb3b41a7..c95bf4ea83 100644 --- a/source/compiler/qsc/src/compile.rs +++ b/source/compiler/qsc/src/compile.rs @@ -29,6 +29,11 @@ pub enum ErrorKind { #[diagnostic(transparent)] Pass(#[from] qsc_passes::Error), + /// Errors from FIR-level transforms (return unification, defunctionalization, + /// monomorphization) that run before capability checking. + #[diagnostic(transparent)] + FirTransform(#[from] qsc_fir_transforms::PipelineError), + /// `Lint` variant represents lints generated during the linting stage. These diagnostics are /// typically emitted from the language server and happens after all other compilation passes. 
#[diagnostic(transparent)] diff --git a/source/compiler/qsc/src/interpret.rs b/source/compiler/qsc/src/interpret.rs index 8e4a2b982a..1136d368d8 100644 --- a/source/compiler/qsc/src/interpret.rs +++ b/source/compiler/qsc/src/interpret.rs @@ -16,6 +16,10 @@ mod tests; use std::{cell::RefCell, rc::Rc}; use crate::{ + codegen::qir::{ + CodegenFir, entry_from_codegen_fir, prepare_codegen_fir, + prepare_codegen_fir_from_callable_args, prepare_codegen_fir_from_fir_store, + }, error::{self, WithStack}, incremental::Compiler, location::Location, @@ -74,10 +78,10 @@ use qsc_lowerer::{ map_fir_local_item_to_hir, map_fir_package_to_hir, map_hir_local_item_to_fir, map_hir_package_to_fir, }; -use qsc_partial_eval::{PartialEvalConfig, ProgramEntry}; +use qsc_partial_eval::PartialEvalConfig; use qsc_passes::{PackageType, PassContext}; use qsc_rca::PackageStoreComputeProperties; -use rustc_hash::FxHashSet; +use rustc_hash::FxHashMap; use thiserror::Error; impl Error { @@ -120,6 +124,9 @@ pub enum Error { #[error("partial evaluation error")] #[diagnostic(transparent)] PartialEvaluation(#[from] WithSource), + #[error("FIR transform error")] + #[diagnostic(transparent)] + FirTransform(#[from] WithSource), } /// A Q# interpreter. @@ -128,8 +135,6 @@ pub struct Interpreter { compiler: Compiler, /// The target capabilities used for compilation. capabilities: TargetCapabilityFlags, - /// The computed properties for the package store, if any, used for code generation. - compute_properties: Option, /// The number of lines that have so far been compiled. /// This field is used to generate a unique label /// for each line evaluated with `eval_fragments`. 
@@ -339,10 +344,14 @@ impl Interpreter { let package_id = compiler.package_id(); let package = map_hir_package_to_fir(package_id); - let compute_properties = if capabilities == TargetCapabilityFlags::all() { - None - } else { - let compute_properties = PassContext::run_fir_passes_on_fir( + + // Run RCA early to surface capability violations at interpreter construction + // time rather than deferring to qirgen()/circuit(). The computed properties + // are intentionally discarded — only the `?` error propagation is used. + // Caching would not help because later backend paths clone the FIR store, + // run the transform pipeline, and re-run RCA on the transformed store. + if capabilities != TargetCapabilityFlags::all() { + let _compute_properties = PassContext::run_fir_passes_on_fir( &fir_store, map_hir_package_to_fir(source_package_id), capabilities, @@ -358,15 +367,12 @@ impl Interpreter { .map(|error| Error::Pass(WithSource::from_map(&source_package.sources, error))) .collect::>() })?; - - Some(compute_properties) - }; + } Ok(Self { compiler, lines: 0, capabilities, - compute_properties, fir_store, lowerer: qsc_lowerer::Lowerer::new(), expr_graph: None, @@ -908,6 +914,188 @@ impl Interpreter { .snapshot(&(self.compiler.package_store(), &self.fir_store)) } + fn prepare_codegen_entry_expr( + &mut self, + expr: &str, + ) -> std::result::Result> { + if self.entry_point_call_expr().as_deref() == Some(expr) { + return self.prepare_codegen_source_package(); + } + + let _ = self.compile_entry_expr(expr)?; + + prepare_codegen_fir_from_fir_store( + self.compiler.package_store(), + map_fir_package_to_hir(self.package), + &self.fir_store, + self.package, + self.capabilities, + ) + } + + fn prepare_codegen_source_package(&self) -> std::result::Result> { + prepare_codegen_fir( + self.compiler.package_store(), + map_fir_package_to_hir(self.source_package), + self.capabilities, + ) + } + + /// Reconstructs the source package's `@EntryPoint` callable as a Q# call + /// expression 
string (e.g., `"MyNamespace.MyOp()"`). + /// + /// Returns `Some` only when the entry expression is a zero-argument call to + /// a resolved named callable. Returns `None` if there is no entry + /// expression, the call has arguments, or the callee is not a simple item + /// reference. + /// + /// This is used in two places: + /// - **Codegen shortcut** (`prepare_codegen_entry_expr`): when the caller + /// passes an expression string that matches the existing entry point, we + /// reuse the already-compiled source package instead of recompiling. + /// - **Default entry fallback** (`compile_to_rir_with_debug_metadata`): + /// when no explicit entry expression is provided, this supplies the + /// `@EntryPoint` callable as the expression to compile. + fn entry_point_call_expr(&self) -> Option { + let source_package = self + .compiler + .package_store() + .get(map_fir_package_to_hir(self.source_package)) + .expect("source package should exist in the package store"); + let entry = source_package.package.entry.as_ref()?; + + let qsc_hir::hir::ExprKind::Call(callee, args) = &entry.kind else { + return None; + }; + let qsc_hir::hir::ExprKind::Tuple(items) = &args.kind else { + return None; + }; + if !items.is_empty() { + return None; + } + + let qsc_hir::hir::ExprKind::Var(qsc_hir::hir::Res::Item(item_id), _) = &callee.kind else { + return None; + }; + let item = source_package.package.items.get(item_id.item)?; + let qsc_hir::hir::ItemKind::Callable(callable) = &item.kind else { + return None; + }; + + let qualified_name = item + .parent + .and_then(|parent_id| source_package.package.items.get(parent_id)) + .and_then(|parent| match &parent.kind { + qsc_hir::hir::ItemKind::Namespace(namespace, _) => { + Some(namespace.name().to_string()) + } + _ => None, + }) + .map_or_else( + || callable.name.name.to_string(), + |namespace| format!("{namespace}.{}", callable.name.name), + ); + + Some(format!("{qualified_name}()")) + } + + /// Extracts an HIR `ItemId` from a runtime 
`Value::Global`. + /// + /// Maps the FIR-domain package and item IDs back to their HIR equivalents + /// for use with the HIR package store in codegen preparation. + /// + /// # Errors + /// + /// Returns `Error::NotACallable` if the value is not a `Value::Global`. + fn hir_item_id_from_value( + callable: &Value, + ) -> std::result::Result> { + let Value::Global(store_item_id, _) = callable else { + return Err(vec![Error::NotACallable]); + }; + + Ok(qsc_hir::hir::ItemId { + package: map_fir_package_to_hir(store_item_id.package), + item: map_fir_local_item_to_hir(store_item_id.item), + }) + } + + /// Normalizes a `StoreItemId` through the HIR↔FIR mapping round-trip. + /// + /// The interpreter's FIR store may use package/item IDs from a different + /// lowering pass than the freshly-lowered codegen store. Round-tripping + /// through `map_fir→hir→fir` ensures IDs align with the codegen store's + /// ID space. + fn remap_store_item_id_for_codegen(store_item_id: fir::StoreItemId) -> fir::StoreItemId { + fir::StoreItemId { + package: map_hir_package_to_fir(map_fir_package_to_hir(store_item_id.package)), + item: map_hir_local_item_to_fir(map_fir_local_item_to_hir(store_item_id.item)), + } + } + + /// Recursively remaps all `StoreItemId` references within a runtime `Value` + /// to the codegen FIR store's ID space. + /// + /// Applies `remap_store_item_id_for_codegen` to every callable reference + /// (`Global`, `Closure`, and UDT-tagged `Tuple`) so the value tree is + /// compatible with the freshly-lowered codegen package store. 
+ fn remap_value_for_codegen(value: Value) -> Value { + match value { + Value::Array(values) => Value::Array(Rc::new( + values + .iter() + .cloned() + .map(Self::remap_value_for_codegen) + .collect(), + )), + Value::Closure(inner) => Value::Closure(Box::new(Closure { + fixed_args: inner + .fixed_args + .iter() + .cloned() + .map(Self::remap_value_for_codegen) + .collect::>() + .into(), + id: Self::remap_store_item_id_for_codegen(inner.id), + functor: inner.functor, + })), + Value::Global(store_item_id, functor_app) => Value::Global( + Self::remap_store_item_id_for_codegen(store_item_id), + functor_app, + ), + Value::Tuple(values, store_item_id) => Value::Tuple( + values + .iter() + .cloned() + .map(Self::remap_value_for_codegen) + .collect::>() + .into(), + store_item_id.map(|id| Rc::new(Self::remap_store_item_id_for_codegen(*id))), + ), + other => other, + } + } + + fn partial_evaluation_error( + &self, + error: qsc_partial_eval::Error, + fallback_package: qsc_hir::hir::PackageId, + ) -> Vec { + let hir_package_id = match error.span() { + Some(span) => span.package, + None => fallback_package, + }; + let source_package = self + .compiler + .package_store() + .get(hir_package_id) + .expect("package should exist in the package store"); + vec![Error::PartialEvaluation(WithSource::from_map( + &source_package.sources, + error, + ))] + } + /// Performs QIR codegen using the given entry expression on a new instance of the environment /// and simulator but using the current compilation. pub fn qirgen(&mut self, expr: &str) -> std::result::Result> { @@ -915,48 +1103,16 @@ impl Interpreter { return Err(vec![Error::UnsupportedRuntimeCapabilities]); } - // Compile the expression. This operation will set the expression as - // the entry-point in the FIR store. 
- let (graph, compute_properties) = self.compile_entry_expr(expr)?; + let prepared_fir = self.prepare_codegen_entry_expr(expr)?; + let entry = entry_from_codegen_fir(&prepared_fir); + let CodegenFir { + fir_store, + fir_package_id, + compute_properties, + } = prepared_fir; - let Some(compute_properties) = compute_properties else { - // This can only happen if capability analysis was not run. This would be a bug - // and we are in a bad state and can't proceed. - panic!("internal error: compute properties not set after lowering entry expression"); - }; - let package = self.fir_store.get(self.package); - let entry = ProgramEntry { - exec_graph: graph, - expr: ( - self.package, - package - .entry - .expect("package must have an entry expression"), - ) - .into(), - }; - // Generate QIR - fir_to_qir( - &self.fir_store, - self.capabilities, - Some(compute_properties), - &entry, - ) - .map_err(|e| { - let hir_package_id = match e.span() { - Some(span) => span.package, - None => map_fir_package_to_hir(self.package), - }; - let source_package = self - .compiler - .package_store() - .get(hir_package_id) - .expect("package should exist in the package store"); - vec![Error::PartialEvaluation(WithSource::from_map( - &source_package.sources, - e, - ))] - }) + fir_to_qir(&fir_store, self.capabilities, &compute_properties, &entry) + .map_err(|e| self.partial_evaluation_error(e, map_fir_package_to_hir(fir_package_id))) } /// Performs QIR codegen using the given callable with the given arguments on a new instance of the environment @@ -970,32 +1126,32 @@ impl Interpreter { return Err(vec![Error::UnsupportedRuntimeCapabilities]); } - let Value::Global(store_item_id, _) = callable else { - return Err(vec![Error::NotACallable]); + let callable_id = Self::hir_item_id_from_value(callable)?; + let backend_args = Self::remap_value_for_codegen(args); + let prepared_fir = prepare_codegen_fir_from_callable_args( + self.compiler.package_store(), + callable_id, + &backend_args, + 
self.capabilities, + )?; + let backend_callable = fir::StoreItemId { + package: map_hir_package_to_fir(callable_id.package), + item: map_hir_local_item_to_fir(callable_id.item), }; + let CodegenFir { + fir_store, + compute_properties, + .. + } = prepared_fir; fir_to_qir_from_callable( - &self.fir_store, + &fir_store, self.capabilities, - None, - *store_item_id, - args, + &compute_properties, + backend_callable, + backend_args, ) - .map_err(|e| { - let hir_package_id = match e.span() { - Some(span) => span.package, - None => map_fir_package_to_hir(self.package), - }; - let source_package = self - .compiler - .package_store() - .get(hir_package_id) - .expect("package should exist in the package store"); - vec![Error::PartialEvaluation(WithSource::from_map( - &source_package.sources, - e, - ))] - }) + .map_err(|e| self.partial_evaluation_error(e, callable_id.package)) } /// Generates a circuit representation for the program. @@ -1100,12 +1256,12 @@ impl Interpreter { return Err(vec![Error::UnsupportedRuntimeCapabilities]); } - let program = self.compile_to_rir_with_debug_metadata(entry_expr)?; + let (program, fir_store) = self.compile_to_rir_with_debug_metadata(entry_expr)?; rir_to_circuit( &program, tracer_config, &[self.package, self.source_package], - &(self.compiler.package_store(), &self.fir_store), + &(self.compiler.package_store(), &fir_store), ) .map_err(|e| vec![e.into()]) } @@ -1120,41 +1276,41 @@ impl Interpreter { return Err(vec![Error::UnsupportedRuntimeCapabilities]); } - let Value::Global(store_item_id, _) = callable else { - return Err(vec![Error::NotACallable]); + let callable_id = Self::hir_item_id_from_value(callable)?; + let backend_args = Self::remap_value_for_codegen(args); + let prepared_fir = prepare_codegen_fir_from_callable_args( + self.compiler.package_store(), + callable_id, + &backend_args, + self.capabilities, + )?; + let backend_callable = fir::StoreItemId { + package: map_hir_package_to_fir(callable_id.package), + item: 
map_hir_local_item_to_fir(callable_id.item), }; + let CodegenFir { + fir_store, + compute_properties, + .. + } = prepared_fir; let (_original, transformed) = fir_to_rir_from_callable( - &self.fir_store, + &fir_store, self.capabilities, - None, - *store_item_id, - args, + &compute_properties, + backend_callable, + backend_args, PartialEvalConfig { generate_debug_metadata: true, }, ) - .map_err(|e| { - let hir_package_id = match e.span() { - Some(span) => span.package, - None => map_fir_package_to_hir(self.package), - }; - let source_package = self - .compiler - .package_store() - .get(hir_package_id) - .expect("package should exist in the package store"); - vec![Error::PartialEvaluation(WithSource::from_map( - &source_package.sources, - e, - ))] - })?; + .map_err(|e| self.partial_evaluation_error(e, callable_id.package))?; rir_to_circuit( &transformed, tracer_config, &[self.package, self.source_package], - &(self.compiler.package_store(), &self.fir_store), + &(self.compiler.package_store(), &fir_store), ) .map_err(|e| vec![e.into()]) } @@ -1162,74 +1318,38 @@ impl Interpreter { fn compile_to_rir_with_debug_metadata( &mut self, entry_expr: Option<&str>, - ) -> std::result::Result> { - let (entry, compute_properties) = if let Some(entry_expr) = &entry_expr { - // Compile the expression. This operation will set the expression as - // the entry-point in the FIR store. - let (graph, compute_properties) = self.compile_entry_expr(entry_expr)?; - - let Some(compute_properties) = compute_properties else { - // This can only happen if capability analysis was not run. 
- panic!( - "internal error: compute properties not set after lowering entry expression" - ); - }; - let package = self.fir_store.get(self.package); - let entry = ProgramEntry { - exec_graph: graph, - expr: ( - self.package, - package - .entry - .expect("package must have an entry expression"), + ) -> std::result::Result<(qsc_partial_eval::Program, qsc_fir::fir::PackageStore), Vec> + { + let (prepared_fir, fallback_package) = + if let Some(entry_expr) = entry_expr.or(self.entry_point_call_expr().as_deref()) { + ( + self.prepare_codegen_entry_expr(entry_expr)?, + map_fir_package_to_hir(self.package), ) - .into(), - }; - (entry, compute_properties) - } else { - let package = self.fir_store.get(self.source_package); - let entry = ProgramEntry { - exec_graph: package.entry_exec_graph.clone(), - expr: ( - self.source_package, - package - .entry - .expect("package must have an entry expression"), + } else { + ( + self.prepare_codegen_source_package()?, + map_fir_package_to_hir(self.source_package), ) - .into(), }; - ( - entry, - self.compute_properties.clone().expect( - "compute properties should be set if target profile isn't unrestricted", - ), - ) - }; + + let entry = entry_from_codegen_fir(&prepared_fir); + let CodegenFir { + fir_store, + compute_properties, + .. 
+ } = prepared_fir; let (_original, transformed) = fir_to_rir( - &self.fir_store, + &fir_store, self.capabilities, - Some(compute_properties), + &compute_properties, &entry, PartialEvalConfig { generate_debug_metadata: true, }, ) - .map_err(|e| { - let hir_package_id = match e.span() { - Some(span) => span.package, - None => map_fir_package_to_hir(self.package), - }; - let source_package = self - .compiler - .package_store() - .get(hir_package_id) - .expect("package should exist in the package store"); - vec![Error::PartialEvaluation(WithSource::from_map( - &source_package.sources, - e, - ))] - })?; - Ok(transformed) + .map_err(|e| self.partial_evaluation_error(e, fallback_package))?; + Ok((transformed, fir_store)) } /// Sets the entry expression for the interpreter. @@ -1405,7 +1525,9 @@ impl Interpreter { } self.lower_and_update_package(unit_addition); - Ok((self.lowerer.take_exec_graph(), None)) + let graph = self.lowerer.take_exec_graph(); + self.fir_store.get_mut(self.package).entry_exec_graph = graph.clone(); + Ok((graph, None)) } fn lower_and_update_package(&mut self, unit: &qsc_frontend::incremental::Increment) { @@ -1446,6 +1568,7 @@ impl Interpreter { })?; let graph = self.lowerer.take_exec_graph(); + self.fir_store.get_mut(self.package).entry_exec_graph = graph.clone(); Ok((graph, Some(compute_properties))) } @@ -1664,7 +1787,7 @@ impl Debugger { self.position_encoding, ); collector.visit_package(package, &self.interpreter.fir_store); - let mut spans: Vec<_> = collector.statements.into_iter().collect(); + let mut spans: Vec<_> = collector.statements.into_values().collect(); // Sort by start position (line first, column next) spans.sort_by_key(|s| (s.range.start.line, s.range.start.column)); @@ -1748,7 +1871,7 @@ pub struct BreakpointSpan { } struct BreakpointCollector<'a> { - statements: FxHashSet, + statements: FxHashMap, sources: &'a SourceMap, offset: u32, package: &'a Package, @@ -1763,7 +1886,7 @@ impl<'a> BreakpointCollector<'a> { 
position_encoding: Encoding, ) -> Self { Self { - statements: FxHashSet::default(), + statements: FxHashMap::default(), sources, offset, package, @@ -1782,11 +1905,18 @@ impl<'a> BreakpointCollector<'a> { if source.offset == self.offset { let span = stmt.span - source.offset; if span != Span::default() { + let range = Range::from_span(self.position_encoding, &source.contents, &span); let bps = BreakpointSpan { id: stmt.id.into(), - range: Range::from_span(self.position_encoding, &source.contents, &span), + range, }; - self.statements.insert(bps); + // Keep the first statement seen for a source range so UI clients get + // one stable, hittable breakpoint per visual location. + // Multiple HIR passes (ReplaceQubitAllocation, LoopUni, + // conjugate_invert, spec_gen) generate statements sharing the same + // source span. The lowerer maps these 1:1 into FIR, so deduplication + // is needed here. + self.statements.entry(range).or_insert(bps); } } } diff --git a/source/compiler/qsc/src/interpret/circuit_tests.rs b/source/compiler/qsc/src/interpret/circuit_tests.rs index 7e9cb65fa2..6c4e8c5230 100644 --- a/source/compiler/qsc/src/interpret/circuit_tests.rs +++ b/source/compiler/qsc/src/interpret/circuit_tests.rs @@ -134,6 +134,40 @@ fn circuit_with_groups(code: &str, entry: CircuitEntryPoint) -> String { eval_circ.display_with_groups().to_string() } +/// Generates a grouped circuit with source locations disabled, asserts that +/// classical evaluation and static generation produce the same grouped display, +/// and returns the static rendering for snapshot comparison. 
+fn circuit_with_groups_without_source_locations(code: &str, entry: CircuitEntryPoint) -> String { + let eval_circ = circuit_with_options_success( + code, + Profile::Unrestricted, + entry.clone(), + CircuitGenerationMethod::ClassicalEval, + TracerConfig { + source_locations: false, + ..default_test_tracer_config() + }, + ); + + let static_circ = circuit_with_options_success( + code, + Profile::AdaptiveRIF, + entry, + CircuitGenerationMethod::Static, + TracerConfig { + source_locations: false, + ..default_test_tracer_config() + }, + ); + + assert_eq!( + eval_circ.display_with_groups().to_string(), + static_circ.display_with_groups().to_string() + ); + + static_circ.display_with_groups().to_string() +} + fn circuit_static(code: &str) -> Circuit { circuit_with_options_success( code, @@ -1609,6 +1643,263 @@ fn operation_declared_in_eval() { .assert_eq(&c.display_with_groups().to_string()); } +#[test] +fn static_entrypoint_handles_callable_returned_from_function() { + let circ = circuit_with_options_success( + r#" + namespace Test { + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + + function GetOp() : Qubit => Unit { + H + } + + @EntryPoint() + operation Main() : Unit { + use q = Qubit(); + ApplyOp(GetOp(), q); + } + } + "#, + Profile::AdaptiveRIF, + CircuitEntryPoint::EntryPoint, + CircuitGenerationMethod::Static, + TracerConfig { + source_locations: false, + ..default_test_tracer_config() + }, + ) + .to_string(); + + expect![[r#" + q_0 ── H ── + "#]] + .assert_eq(&circ); +} + +#[test] +fn grouped_scopes_use_source_name_for_specialized_direct_callables() { + let circ = circuit_with_groups_without_source_locations( + r#" + namespace Test { + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + + @EntryPoint() + operation Main() : Unit { + use q = Qubit(); + ApplyOp(H, q); + } + } + "#, + CircuitEntryPoint::EntryPoint, + ); + + expect![[r#" + q_0 ─ [ [Main] ─── [ [ApplyOp] ─── H ──── ] ──── ] ── + "#]] + .assert_eq(&circ); +} 
+ +#[test] +fn grouped_scopes_use_source_name_for_specialized_callable_arrays() { + let circ = circuit_with_groups_without_source_locations( + r#" + namespace Test { + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + + @EntryPoint() + operation Main() : Unit { + use q = Qubit(); + let ops = [H, X]; + for op in ops { + ApplyOp(op, q); + } + } + } + "#, + CircuitEntryPoint::EntryPoint, + ); + + expect![[r#" + q_0 ─ [ [Main] ─── [ [loop: ops] ── [ [(1)] ── [ [ApplyOp] ─── H ──── ] ──── ] ─── [ [(2)] ── [ [ApplyOp] ─── X ──── ] ──── ] ──── ] ──── ] ── + "#]] + .assert_eq(&circ); +} + +#[test] +fn grouped_scopes_match_for_user_defined_adjoint_specialization() { + let circ = circuit_with_groups_without_source_locations( + r#" + namespace Test { + operation EncodeAsLogicalQubit(physicalQubit : Qubit, aux : Qubit[]) : Unit is Adj { + ApplyToEachA(CNOT(physicalQubit, _), aux); + } + + @EntryPoint() + operation Main() : Unit { + use logicalQubit = Qubit[3]; + EncodeAsLogicalQubit(logicalQubit[0], logicalQubit[1...]); + Adjoint EncodeAsLogicalQubit(logicalQubit[0], logicalQubit[1...]); + } + } + "#, + CircuitEntryPoint::EntryPoint, + ); + + expect![[r#" + q_0 ─ Main[1] ─ + ┆ + q_1 ─ Main[1] ─ + ┆ + q_2 ─ Main[1] ─ + + [1] Main: + q_0 ─ EncodeAsLogicalQubit[2] ── EncodeAsLogicalQubit'[3] ── + ┆ ┆ + q_1 ─ EncodeAsLogicalQubit[2] ── EncodeAsLogicalQubit'[3] ── + ┆ ┆ + q_2 ─ EncodeAsLogicalQubit[2] ── EncodeAsLogicalQubit'[3] ── + + [2] EncodeAsLogicalQubit: + q_0 ─ [4] ─ + ┆ + q_1 ─ [4] ─ + ┆ + q_2 ─ [4] ─ + + [3] EncodeAsLogicalQubit: + q_0 ─ '[5] ── + ┆ + q_1 ─ '[5] ── + ┆ + q_2 ─ '[5] ── + + [4] : + q_0 ── ● ──── ● ── + q_1 ── X ─────┼─── + q_2 ───────── X ── + + [5] : + q_0 ── ● ──── ● ── + q_1 ───┼───── X ── + q_2 ── X ───────── + "#]] + .assert_eq(&circ); +} + +#[test] +fn grouped_scopes_match_for_apply_operation_power_ca_lambda() { + let circ = circuit_with_groups_without_source_locations( + r#" + namespace Test { + operation U(q : Qubit) : Unit is 
Ctl + Adj { + Rz(Std.Math.PI() / 3.0, q); + } + + @EntryPoint() + operation Main() : Unit { + use state = Qubit(); + use phase = Qubit[2]; + let oracle = ApplyOperationPowerCA(_, qs => U(qs[0]), _); + ApplyQPE(oracle, [state], phase); + } + } + "#, + CircuitEntryPoint::EntryPoint, + ); + + expect![[r#" + q_0 ─ Main[1] ─ + ┆ + q_1 ─ Main[1] ─ + ┆ + q_2 ─ Main[1] ─ + + [1] Main: + q_0 ──────── U[2] ──────────────────────────────────────────────────────────────────── + ┆ + q_1 ── H ─── U[2] ──────── H ─────── Rz(-0.7854) ─── X ─── Rz(0.7854) ──── X ───────── + ┆ │ │ + q_2 ── H ─── U[2] ─── Rz(-0.7854) ────────────────── ● ─────────────────── ● ──── H ── + + [2] U: + q_0 ─ Rz(0.5236) ──── X ─── Rz(-0.5236) ─── X ─── Rz(0.5236) ──── X ─── Rz(-0.5236) ─── X ─── Rz(0.5236) ──── X ─── Rz(-0.5236) ─── X ── + q_1 ───────────────── ● ─────────────────── ● ─────────────────── ● ─────────────────── ● ────────────────────┼─────────────────────┼─── + q_2 ───────────────────────────────────────────────────────────────────────────────────────────────────────── ● ─────────────────── ● ── + "#]] + .assert_eq(&circ); +} + +#[test] +fn grouped_scopes_match_for_repeated_draw_random_bit_calls() { + let circ = circuit_with_groups_without_source_locations( + r#" + namespace Test { + operation DrawRandomBit() : Unit { + use q = Qubit(); + H(q); + MResetZ(q); + } + + @EntryPoint() + operation Main() : Unit { + DrawRandomBit(); + DrawRandomBit(); + } + } + "#, + CircuitEntryPoint::EntryPoint, + ); + + expect![[r#" + q_0 ─ Main[1] ─ + ╘═════ + ╘═════ + + [1] Main: + q_0 ─ DrawRandomBit[2] ─── DrawRandomBit[3] ── + ╘════════════════════┆══════════ + ╘══════════ + + [2] DrawRandomBit: + q_0 ── H ──── M ──── |0〉 ── + ╘════════════ + + + [3] DrawRandomBit: + q_0 ── H ──── M ──── |0〉 ── + │ + ╘════════════ + "#]] + .assert_eq(&circ); +} + +#[test] +fn static_entrypoint_handles_struct_copy_update() { + let circ = circuit_static( + r#" + namespace Test { + @EntryPoint() + operation Main() : Unit { + 
struct Point3d { X : Double, Y : Double, Z : Double } + + mutable point = new Point3d { X = 0.0, Y = 0.0, Z = 0.0 }; + point = new Point3d { ...point, X = point.X + 1.0 }; + let x = point.X; + } + } + "#, + ); + + expect![""].assert_eq(&circ.to_string()); +} + /// Tests that invoke circuit generation through the debugger. mod debugger_stepping { use super::Debugger; diff --git a/source/compiler/qsc/src/interpret/debugger_tests.rs b/source/compiler/qsc/src/interpret/debugger_tests.rs index c7dfd73c63..97f90f6ef4 100644 --- a/source/compiler/qsc/src/interpret/debugger_tests.rs +++ b/source/compiler/qsc/src/interpret/debugger_tests.rs @@ -117,9 +117,22 @@ mod given_debugger { p } }"#; + + static DUPLICATE_RANGE_SOURCE: &str = r#" + namespace Sample { + @EntryPoint() + operation Main() : Result[] { + use q1 = Qubit(); + Y(q1); + let m1 = M(q1); + return [m1]; + } + }"#; + #[cfg(test)] mod step { use qsc_data_structures::{source::SourceMap, target::TargetCapabilityFlags}; + use rustc_hash::FxHashSet; + use super::*; @@ -238,5 +251,36 @@ mod given_debugger { expect_return(debugger, expected); Ok(()) } + + #[test] + fn duplicate_source_ranges_collapse_to_one_hittable_breakpoint() + -> Result<(), Vec<Error>> { + let sources = SourceMap::new([("test.qs".into(), DUPLICATE_RANGE_SOURCE.into())], None); + let (std_id, store) = + crate::compile::package_store_with_stdlib(TargetCapabilityFlags::all()); + let mut debugger = Debugger::new( + sources, + TargetCapabilityFlags::all(), + Encoding::Utf8, + LanguageFeatures::default(), + store, + &[(std_id, None)], + )?; + + let breakpoints = debugger.get_breakpoints("test.qs"); + assert_eq!(breakpoints.len(), 4); + + let unique_ranges: FxHashSet<_> = breakpoints.iter().map(|bp| bp.range).collect(); + assert_eq!(unique_ranges.len(), breakpoints.len()); + + let return_breakpoint_id = breakpoints + .last() + .expect("expected a return breakpoint") + .id + .into(); + + expect_bp(&mut debugger, &[return_breakpoint_id], return_breakpoint_id); + 
Ok(()) + } } } diff --git a/source/compiler/qsc/src/interpret/tests.rs b/source/compiler/qsc/src/interpret/tests.rs index 8c8cbe8077..98c7c360e4 100644 --- a/source/compiler/qsc/src/interpret/tests.rs +++ b/source/compiler/qsc/src/interpret/tests.rs @@ -1022,6 +1022,339 @@ mod given_interpreter { "#]].assert_eq(&res); } + fn assert_qir_has_three_h_gates(qir: &str) { + assert!( + qir.contains("define i64 @ENTRYPOINT__main()"), + "expected entry point in generated QIR, got:\n{qir}" + ); + assert!( + qir.contains(r#""required_num_qubits"="3""#), + "expected three qubits in generated QIR, got:\n{qir}" + ); + assert_eq!( + qir.matches("call void @__quantum__qis__h__body").count(), + 3, + "expected three H applications in generated QIR, got:\n{qir}" + ); + } + + fn user_global(interpreter: &Interpreter, name: &str) -> Value { + interpreter + .user_globals() + .into_iter() + .find_map(|(_, global_name, value)| (global_name.as_ref() == name).then_some(value)) + .unwrap_or_else(|| panic!("{name} should be present in user globals")) + } + + #[test] + fn qirgen_does_not_corrupt_later_interpreter_eval_or_recompilation() { + let mut interpreter = get_interpreter_with_capabilities(TargetCapabilityFlags::empty()); + let (result, output) = line( + &mut interpreter, + indoc! 
{"operation Foo() : Result { use q = Qubit(); let r = M(q); Reset(q); return r; } "}, + ); + is_only_value(&result, &output, &Value::unit()); + + interpreter.qirgen("Foo()").expect("expected success"); + + let (result, output) = line(&mut interpreter, "Foo()"); + is_only_value( + &result, + &output, + &Value::Result(qsc_eval::val::Result::Val(false)), + ); + + let (result, output) = line(&mut interpreter, "operation Bar() : Result { Foo() }"); + is_only_value(&result, &output, &Value::unit()); + let (result, output) = line(&mut interpreter, "Bar()"); + is_only_value( + &result, + &output, + &Value::Result(qsc_eval::val::Result::Val(false)), + ); + } + + #[test] + fn qirgen_from_callable_user_global_succeeds_after_fresh_lowering() { + let mut interpreter = get_interpreter_with_capabilities(TargetCapabilityFlags::empty()); + let (result, output) = line( + &mut interpreter, + indoc! {"operation Foo() : Result { use q = Qubit(); let r = M(q); Reset(q); return r; } "}, + ); + is_only_value(&result, &output, &Value::unit()); + + let callable = user_global(&interpreter, "Foo"); + + let res = interpreter + .qirgen_from_callable(&callable, Value::unit()) + .expect("expected success"); + + expect![[r#" + %Result = type opaque + %Qubit = type opaque + + @0 = internal constant [4 x i8] c"0_r\00" + + define i64 @ENTRYPOINT__main() #0 { + block_0: + call void @__quantum__rt__initialize(i8* null) + call void @__quantum__qis__cx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) + ret i64 0 + } + + declare void @__quantum__rt__initialize(i8*) + + declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 + + declare void @__quantum__rt__result_record_output(%Result*, i8*) + + declare void 
@__quantum__qis__cx__body(%Qubit*, %Qubit*) + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="2" "required_num_results"="1" } + attributes #1 = { "irreversible" } + + ; module flags + + !llvm.module.flags = !{!0, !1, !2, !3} + + !0 = !{i32 1, !"qir_major_version", i32 1} + !1 = !{i32 7, !"qir_minor_version", i32 0} + !2 = !{i32 1, !"dynamic_qubit_management", i1 false} + !3 = !{i32 1, !"dynamic_result_management", i1 false} + "#]] + .assert_eq(&res); + } + + #[test] + fn qirgen_from_callable_with_global_callable_arg_succeeds() { + let mut interpreter = get_interpreter_with_capabilities(TargetCapabilityFlags::empty()); + let (result, output) = line( + &mut interpreter, + indoc! {r#" + open Std.Canon; + + operation InvokeWithQubits(nQubits : Int, f : Qubit[] => Unit) : Unit { + use qs = Qubit[nQubits]; + f(qs); + } + + operation AllH(qs : Qubit[]) : Unit { + struct Point3d { X : Double, Y : Double, Z : Double } + + let point = new Point3d { X = 1.0, Y = 2.0, Z = 3.0 }; + let point2 = new Point3d { ...point, Z = 4.0 }; + let should_apply = point2.X == 1.0; + if should_apply { + ApplyToEach(H, qs); + } + } + + operation UnusedIntOutput() : Int { + 1 + } + "#}, + ); + is_only_value(&result, &output, &Value::unit()); + + let invoke_with_qubits = user_global(&interpreter, "InvokeWithQubits"); + let all_h = user_global(&interpreter, "AllH"); + + let qir = interpreter + .qirgen_from_callable( + &invoke_with_qubits, + Value::Tuple(vec![Value::Int(3), all_h].into(), None), + ) + .expect("expected success"); + + assert_qir_has_three_h_gates(&qir); + } + + #[test] + fn qirgen_from_callable_with_closure_arg_succeeds() { + let mut interpreter = get_interpreter_with_capabilities(TargetCapabilityFlags::empty()); + let (result, output) = line( + &mut interpreter, + indoc! 
{r#" + open Std.Canon; + + operation InvokeWithQubits(nQubits : Int, f : Qubit[] => Unit) : Unit { + use qs = Qubit[nQubits]; + f(qs); + } + "#}, + ); + is_only_value(&result, &output, &Value::unit()); + + let invoke_with_qubits = user_global(&interpreter, "InvokeWithQubits"); + + let (closure_result, closure_output) = line(&mut interpreter, "ApplyToEach(H, _)"); + assert!( + closure_output.is_empty(), + "unexpected output while creating closure: {closure_output}" + ); + let apply_h = closure_result.expect("expected closure value"); + + let qir = interpreter + .qirgen_from_callable( + &invoke_with_qubits, + Value::Tuple(vec![Value::Int(3), apply_h].into(), None), + ) + .expect("expected success"); + + assert_qir_has_three_h_gates(&qir); + } + + #[test] + fn qirgen_from_callable_with_arrow_input_reports_runtime_capability_errors() { + let mut interpreter = get_interpreter_with_capabilities( + TargetCapabilityFlags::Adaptive | TargetCapabilityFlags::IntegerComputations, + ); + let (result, output) = line( + &mut interpreter, + indoc! 
{r#" + import Std.Convert.*; + + operation InvokeWithMeasuredInt(f : (Int, Qubit) => Unit) : Unit { + use q = Qubit(); + let i = if MResetZ(q) == One { 1 } else { 0 }; + f(i, q); + } + + operation RotateByInt(i : Int, q : Qubit) : Unit { + Rx(IntAsDouble(i), q); + } + "#}, + ); + is_only_value(&result, &output, &Value::unit()); + + let invoke_with_measured_int = user_global(&interpreter, "InvokeWithMeasuredInt"); + let rotate_by_int = user_global(&interpreter, "RotateByInt"); + + let errors = interpreter + .qirgen_from_callable(&invoke_with_measured_int, rotate_by_int) + .expect_err("expected runtime capability error"); + + assert!( + errors + .iter() + .all(|error| matches!(error, crate::interpret::Error::PartialEvaluation(_))), + "expected deferred partial-evaluation capability errors, got {errors:?}" + ); + assert!( + errors + .iter() + .any(|error| format!("{error:?}").contains("UseOfDynamicDouble")), + "expected a dynamic double capability diagnostic, got {errors:?}" + ); + } + + #[test] + fn qirgen_from_callable_profile_incompatible_outputs_report_callable_scoped_errors() { + let mut interpreter = get_interpreter_with_capabilities(TargetCapabilityFlags::empty()); + let (result, output) = line( + &mut interpreter, + indoc! 
{r#" + operation ReturnInt() : Int { + 1 + } + + operation ReturnDouble() : Double { + 1.0 + } + + operation ReturnBool() : Bool { + true + } + + operation ReturnString() : String { + "hello" + } + "#}, + ); + is_only_value(&result, &output, &Value::unit()); + + let int_errors = interpreter + .qirgen_from_callable(&user_global(&interpreter, "ReturnInt"), Value::unit()) + .expect_err("expected integer output rejection"); + is_error( + &int_errors, + &expect![[r#" + cannot use an integer value as an output + [line_0] [ReturnInt] + "#]], + ); + + let double_errors = interpreter + .qirgen_from_callable(&user_global(&interpreter, "ReturnDouble"), Value::unit()) + .expect_err("expected double output rejection"); + is_error( + &double_errors, + &expect![[r#" + cannot use a double value as an output + [line_0] [ReturnDouble] + "#]], + ); + + let bool_errors = interpreter + .qirgen_from_callable(&user_global(&interpreter, "ReturnBool"), Value::unit()) + .expect_err("expected bool output rejection"); + is_error( + &bool_errors, + &expect![[r#" + cannot use a bool value as an output + [line_0] [ReturnBool] + "#]], + ); + + let advanced_errors = interpreter + .qirgen_from_callable(&user_global(&interpreter, "ReturnString"), Value::unit()) + .expect_err("expected advanced output rejection"); + is_error( + &advanced_errors, + &expect![[r#" + cannot use value with advanced type as an output + [line_0] [ReturnString] + "#]], + ); + } + + #[test] + fn qirgen_from_callable_does_not_corrupt_later_interpreter_eval_or_recompilation() { + let mut interpreter = get_interpreter_with_capabilities(TargetCapabilityFlags::empty()); + let (result, output) = line( + &mut interpreter, + indoc! 
{"operation Foo() : Result { use q = Qubit(); let r = M(q); Reset(q); return r; } "}, + ); + is_only_value(&result, &output, &Value::unit()); + + let callable = user_global(&interpreter, "Foo"); + + interpreter + .qirgen_from_callable(&callable, Value::unit()) + .expect("expected success"); + + let mut cursor = Cursor::new(Vec::::new()); + let mut receiver = CursorReceiver::new(&mut cursor); + let result = interpreter.invoke(&mut receiver, callable.clone(), Value::unit()); + let output = receiver.dump(); + is_only_value( + &result, + &output, + &Value::Result(qsc_eval::val::Result::Val(false)), + ); + + let (result, output) = line(&mut interpreter, "operation Bar() : Result { Foo() }"); + is_only_value(&result, &output, &Value::unit()); + let (result, output) = line(&mut interpreter, "Bar()"); + is_only_value( + &result, + &output, + &Value::Result(qsc_eval::val::Result::Val(false)), + ); + } + #[test] fn adaptive_qirgen() { let mut interpreter = get_interpreter_with_capabilities( @@ -1090,6 +1423,117 @@ mod given_interpreter { .assert_eq(&res); } + #[test] + fn adaptive_qirgen_source_entrypoint_uses_fresh_lowering() { + let mut interpreter = get_interpreter_with_capabilities( + TargetCapabilityFlags::Adaptive | TargetCapabilityFlags::IntegerComputations, + ); + let (result, output) = line( + &mut interpreter, + indoc! 
{r#" + namespace Test { + import Std.Intrinsic.*; + import Std.Math.*; + import Std.Measurement.*; + + @EntryPoint() + operation Main() : ((Result[], Int), Bool) { + use registerA = Qubit[3]; + if true { + X(registerA[0]); + if true { + X(registerA[1]); + if false { + X(registerA[2]); + } + } + } + let registerAMeasurements = MeasureEachZ(registerA); + + mutable a = 0; + if registerAMeasurements[0] == Zero { + if registerAMeasurements[1] == Zero and registerAMeasurements[2] == Zero { + set a = 0; + } elif registerAMeasurements[1] == Zero and registerAMeasurements[2] == One { + set a = 1; + } elif registerAMeasurements[1] == One and registerAMeasurements[2] == Zero { + set a = 2; + } else { + set a = 3; + } + } else { + if registerAMeasurements[1] == Zero and registerAMeasurements[2] == Zero { + set a = 4; + } elif registerAMeasurements[1] == Zero and registerAMeasurements[2] == One { + set a = 5; + } elif registerAMeasurements[1] == One and registerAMeasurements[2] == Zero { + set a = 6; + } else { + set a = 7; + } + } + ResetAll(registerA); + + use q = Qubit(); + ((registerAMeasurements, a), MResetZ(q) == One) + } + }"# + }, + ); + is_only_value(&result, &output, &Value::unit()); + + let qir = interpreter.qirgen("Test.Main()").expect("expected success"); + + assert!( + qir.contains("call void @__quantum__rt__int_record_output(i64 %var_"), + "expected dynamic integer output to be recorded from an SSA value, got:\n{qir}" + ); + assert!( + !qir.contains("call void @__quantum__rt__int_record_output(i64 0,"), + "expected source entrypoint QIR generation to avoid stale literal outputs, got:\n{qir}" + ); + } + + #[test] + fn adaptive_qirgen_source_entrypoint_supports_measurement_comparisons() { + let mut interpreter = get_interpreter_with_capabilities( + TargetCapabilityFlags::Adaptive | TargetCapabilityFlags::IntegerComputations, + ); + let (result, output) = line( + &mut interpreter, + indoc! 
{r#" + namespace Test { + import Std.Intrinsic.*; + + @EntryPoint() + operation Main() : (Bool, Bool, Bool, Bool) { + use (q0, q1) = (Qubit(), Qubit()); + X(q0); + CNOT(q0, q1); + let (r0, r1) = (M(q0), M(q1)); + Reset(q0); + Reset(q1); + return (r0 == One, r1 == Zero, r0 == r1, r0 == Zero ? false | true); + } + }"# + }, + ); + is_only_value(&result, &output, &Value::unit()); + + let qir = interpreter.qirgen("Test.Main()").expect("expected success"); + + assert!( + qir.contains( + "call i1 @__quantum__rt__read_result(%Result* inttoptr (i64 0 to %Result*))" + ), + "expected measurement comparisons to lower through read_result, got:\n{qir}" + ); + assert!( + qir.contains("icmp eq i1 %var_4, %var_5"), + "expected result-to-result equality to lower to an i1 comparison, got:\n{qir}" + ); + } + #[test] fn adaptive_qirgen_nested_output_types() { let mut interpreter = @@ -1233,6 +1677,26 @@ mod given_interpreter { "#]].assert_eq(&res); } + #[test] + fn adaptive_rif_qirgen_entry_expr_apply_to_each_sx() { + let mut interpreter = get_interpreter_with_capabilities( + TargetCapabilityFlags::Adaptive + | TargetCapabilityFlags::IntegerComputations + | TargetCapabilityFlags::FloatingPointComputations, + ); + let (result, output) = line(&mut interpreter, indoc! 
{"open Std.Canon;"}); + is_only_value(&result, &output, &Value::unit()); + + let res = interpreter + .qirgen("{ use qs = Qubit[4]; ApplyToEach(SX, qs); }") + .expect("expected success"); + + assert!( + res.contains("declare void @__quantum__qis__sx__body(%Qubit*)"), + "expected ApplyToEach(SX, qs) to generate SX calls, got:\n{res}" + ); + } + #[test] fn qirgen_entry_expr_defines_operation() { let mut interpreter = get_interpreter_with_capabilities(TargetCapabilityFlags::empty()); diff --git a/source/compiler/qsc/src/lib.rs b/source/compiler/qsc/src/lib.rs index 2c3194af02..636a4f4ed3 100644 --- a/source/compiler/qsc/src/lib.rs +++ b/source/compiler/qsc/src/lib.rs @@ -87,3 +87,7 @@ pub mod target { } pub mod openqasm; + +pub mod fir_transforms { + pub use qsc_fir_transforms::{defunctionalize, run_pipeline}; +} diff --git a/source/compiler/qsc/src/openqasm.rs b/source/compiler/qsc/src/openqasm.rs index 5ad57cd191..c375b95bef 100644 --- a/source/compiler/qsc/src/openqasm.rs +++ b/source/compiler/qsc/src/openqasm.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use std::vec; use qsc_data_structures::error::WithSource; +use qsc_data_structures::target::Profile; use qsc_frontend::compile::PackageStore; use qsc_hir::hir::PackageId; use qsc_openqasm_compiler::compiler::parse_and_compile_to_qsharp_ast_with_config; @@ -56,7 +57,26 @@ pub struct CompileRawQasmResult( #[must_use] pub fn compile_openqasm(unit: QasmCompileUnit, package_type: PackageType) -> CompileRawQasmResult { - let (source_map, openqasm_errors, package, sig, profile) = unit.into_tuple(); + compile_openqasm_with_profile_override(unit, package_type, None) +} + +/// Compiles `OpenQASM` to Q# with optional explicit profile override. +/// +/// Profile precedence: +/// 1. `profile_override` (if provided) +/// 2. Pragma-derived profile from `OpenQASM` source +/// 3. 
Default to `Profile::Unrestricted` +/// +/// This enables cleaner profile management across `OpenQASM` compilation flows, +/// allowing callers to explicitly control the QIR profile used for circuit/QIR generation. +#[must_use] +pub fn compile_openqasm_with_profile_override( + unit: QasmCompileUnit, + package_type: PackageType, + profile_override: Option<Profile>, +) -> CompileRawQasmResult { + let (source_map, openqasm_errors, package, sig, pragma_profile) = unit.into_tuple(); + let profile = profile_override.unwrap_or(pragma_profile.unwrap_or(Profile::Unrestricted)); let (stdid, mut store) = package_store_with_stdlib(profile.into()); let dependencies = vec![(PackageId::CORE, None), (stdid, None)]; diff --git a/source/compiler/qsc_circuit/src/builder.rs b/source/compiler/qsc_circuit/src/builder.rs index 806a4c457f..1dc79ed5e3 100644 --- a/source/compiler/qsc_circuit/src/builder.rs +++ b/source/compiler/qsc_circuit/src/builder.rs @@ -15,6 +15,7 @@ use qsc_data_structures::{ functors::FunctorApp, index_map::IndexMap, line_column::{Encoding, Position}, + span::Span, }; use qsc_eval::{ backend::Tracer, @@ -25,7 +26,7 @@ use qsc_fir::fir::{ self, ExprId, ExprKind, PackageId, PackageLookup, PackageStoreLookup, StoreItemId, }; use qsc_frontend::compile::{self}; -use qsc_lowerer::map_fir_package_to_hir; +use qsc_lowerer::{map_fir_package_to_hir, map_hir_package_to_fir}; use rustc_hash::{FxHashMap, FxHashSet}; #[cfg(test)] use std::fmt::Display; @@ -264,6 +265,7 @@ impl CircuitTracer { operations, qubits, self.config.group_by_scope, + &self.user_package_ids, ) } @@ -497,17 +499,21 @@ impl CircuitTracer { } } -/// Take a sequence of operations and build the final `Circuit`. -/// Operations are laid out into columns. Unnecessary groups are removed. -/// Source location metadata is resolved into displayable file/line/column information. +/// Constructs the final circuit representation from operations and qubits.
+/// +/// This function: +/// - Optionally collapses unnecessary scope groups based on user/library package origin +/// - Lays out operations into columns for circuit visualization +/// - Resolves source location metadata into displayable file/line/column information pub(crate) fn finish_circuit( source_lookup: &impl SourceLookup, mut operations: Vec<OperationOrGroup>, qubits: Vec<Qubit>, collapse_trivial_groups: bool, + user_package_ids: &[PackageId], ) -> Circuit { if collapse_trivial_groups { - collapse_unnecessary_scopes(&mut operations, source_lookup); + collapse_unnecessary_scopes(&mut operations, source_lookup, user_package_ids); } let mut loop_id_cache = Default::default(); let operations = operations @@ -526,33 +532,137 @@ pub(crate) fn finish_circuit( /// An unnecessary loop scope is one that either has a single child iteration, /// or has multiple iterations that each operate on distinct sets of qubits (i.e. a "vertical" loop). /// An unnecessary lambda scope is one where the lambda has a single child operation. +/// Recursively collapses unnecessary scope groups and merges equivalent adjacent groups. +/// +/// An operation/group is considered unnecessary if: +/// - It's a loop scope with a single child iteration +/// - It's a loop scope where all iterations operate on disjoint qubit sets ("vertical" loop) +/// - It's a lambda scope with a single child, and is not a partial lambda from `ApplyToEach` +/// - It's a synthesized callable scope from a non-user package +/// +/// After collapsing, adjacent groups with equivalent scopes tied to synthesized callable +/// ancestry are merged to further reduce noise in the circuit display. fn collapse_unnecessary_scopes( operations: &mut Vec<OperationOrGroup>, source_lookup: &impl SourceLookup, + user_package_ids: &[PackageId], ) { let mut ops = vec![]; for mut op in operations.drain(..) { match &mut op.kind { OperationOrGroupKind::Single => {} OperationOrGroupKind::Group { children, ..
} => { - collapse_unnecessary_scopes(children, source_lookup); + collapse_unnecessary_scopes(children, source_lookup, user_package_ids); } } - if let Some(children) = collapse_if_unnecessary(&mut op, source_lookup) { + if let Some(children) = collapse_if_unnecessary(&mut op, source_lookup, user_package_ids) { ops.extend(children); } else { ops.push(op); } } + merge_adjacent_equivalent_groups(&mut ops, source_lookup); *operations = ops; } -/// If the given operation or group is an outer scope that can be collapsed, -/// returns its children operations or groups. +/// Merges adjacent operation groups that are equivalent and share synthesized callable ancestry. +/// +/// Groups are merged when they: +/// - Have the same current lexical scope +/// - Share a synthesized callable ancestor in their scope stack (indicating they stem from +/// synthetic transformations like specialization or closure wrapping) +/// +/// This reduces visual clutter by consolidating synthetic groupings that represent +/// the same logical scope applied to different iterations or cases. +fn merge_adjacent_equivalent_groups( + operations: &mut Vec<OperationOrGroup>, + source_lookup: &impl SourceLookup, +) { + let mut merged = Vec::with_capacity(operations.len()); + + for mut op in operations.drain(..) { + if let Some(last) = merged.last_mut() + && can_merge_equivalent_group(last, &op, source_lookup) + { + merge_equivalent_group(last, &mut op); + continue; + } + + merged.push(op); + } + + *operations = merged; +} + +/// Determines whether two adjacent groups can be merged. +/// +/// Groups can merge if they have the same lexical scope AND at least one has a +/// synthesized callable ancestor, indicating they are synthetic variations of the same scope.
+fn can_merge_equivalent_group( + last: &OperationOrGroup, + next: &OperationOrGroup, + source_lookup: &impl SourceLookup, +) -> bool { + matches!( + (last.scope_stack_if_group(), next.scope_stack_if_group()), + (Some(last_scope_stack), Some(next_scope_stack)) + if last_scope_stack.current_lexical_scope() == next_scope_stack.current_lexical_scope() + && (has_synthesized_callable_ancestor(last_scope_stack, source_lookup) + || has_synthesized_callable_ancestor(next_scope_stack, source_lookup)) + ) +} + +/// Checks whether a scope stack has a synthesized callable ancestor. +/// +/// Synthesized callables arise from compiler transformations like specialization, functor +/// application, or closure wrapping. A scope has a synthesized ancestor if any callable +/// in its caller chain is marked as synthesized. +fn has_synthesized_callable_ancestor( + scope_stack: &ScopeStack, + source_lookup: &impl SourceLookup, +) -> bool { + scope_stack.caller().0.iter().any(|entry| { + matches!(entry.lexical_scope(), Scope::Callable(..)) + && source_lookup.is_synthesized_callable_scope(entry.lexical_scope()) + }) +} + +/// Merges the next group into the last group by combining their child operations. +/// +/// Propagates inputs from next into last, then appends all child operations from next +/// to last's children, consolidating them into a single group. +fn merge_equivalent_group(last: &mut OperationOrGroup, next: &mut OperationOrGroup) { + last.merge_inputs(next); + + let next_children = match &mut next.kind { + OperationOrGroupKind::Group { children, .. } => take(children), + OperationOrGroupKind::Single => Vec::new(), + }; + + let last_children = match &mut last.kind { + OperationOrGroupKind::Group { children, .. } => children, + OperationOrGroupKind::Single => { + unreachable!("can_merge_equivalent_group only matches groups") + } + }; + + last_children.extend(next_children); +} + +/// Determines whether a scope group should be collapsed and returns its flattened children. 
+/// +/// Returns `Some(children)` if the group is unnecessary and can be safely removed; +/// `None` if the group should be preserved. +/// +/// Collapse rules: +/// - **Loop scopes**: Collapse if from non-user package, has single child, or operates on disjoint qubits +/// - **Lambda scopes**: Collapse if single child and not a partial lambda from `ApplyToEach` +/// - **Synthesized callables**: Collapse based on origin package and synthetic status fn collapse_if_unnecessary( op: &mut OperationOrGroup, source_lookup: &impl SourceLookup, + user_package_ids: &[PackageId], ) -> Option<Vec<OperationOrGroup>> { if let OperationOrGroupKind::Group { scope_stack, @@ -560,6 +670,12 @@ fn collapse_if_unnecessary( } = &mut op.kind { if let Scope::Loop(..) = scope_stack.current_lexical_scope() { + let scope = source_lookup + .resolve_scope(scope_stack.current_lexical_scope(), &mut Default::default()); + if should_collapse_non_user_loop_scope(&scope, user_package_ids) { + return Some(flatten_loop_iteration_children(children)); + } + if children.len() == 1 { // remove the loop scope let mut only_child = children.remove(0); @@ -586,21 +702,143 @@ all_children.extend(take(children)); } return Some(all_children); - } else if let Scope::Callable(..) = scope_stack.current_lexical_scope() - && children.len() == 1 - && source_lookup - .resolve_scope(scope_stack.current_lexical_scope(), &mut Default::default()) - .name - .as_ref() - == "<lambda>" - { - // remove the lambda scope - return Some(take(children)); + } else if let Scope::Callable(..)
= scope_stack.current_lexical_scope() { + let scope = source_lookup + .resolve_scope(scope_stack.current_lexical_scope(), &mut Default::default()); + if children.len() == 1 + && scope.name.as_ref() == "<lambda>" + && !should_preserve_apply_to_each_partial_lambda( + source_lookup, + scope_stack, + user_package_ids, + ) + { + // remove the lambda scope + return Some(take(children)); + } + + if should_collapse_synthesized_callable_scope( + source_lookup, + scope_stack.current_lexical_scope(), + user_package_ids, + ) { + return Some(take(children)); + } + } + } + None +} + +/// Determines whether a lambda scope should be preserved to maintain `ApplyToEach` structure. +/// +/// Preserves lambda scopes that are: +/// - Partial lambdas created within `ApplyToEach` closures +/// - Called from user code (not synthesized) +/// +/// This ensures that higher-order loop patterns like `ApplyToEach(op, qubits)` remain +/// readable in circuit displays rather than being flattened away. +fn should_preserve_apply_to_each_partial_lambda( + source_lookup: &impl SourceLookup, + scope_stack: &ScopeStack, + user_package_ids: &[PackageId], +) -> bool { + let mut loop_id_cache = Default::default(); + let mut saw_apply_to_each_closure = false; + + for caller in scope_stack.caller().0.iter().rev() { + let scope = caller.lexical_scope(); + + if matches!(scope, Scope::Loop(..) | Scope::LoopIteration(..)) { + continue; + } + + let Scope::Callable(..)
= scope else { + return false; + }; + + let resolved_scope = source_lookup.resolve_scope(scope, &mut loop_id_cache); + + if source_lookup.is_synthesized_callable_scope(scope) + && resolved_scope.name.as_ref().starts_with("ApplyToEach") + { + saw_apply_to_each_closure = true; + continue; + } + + if saw_apply_to_each_closure { + return source_lookup + .callable_scope_origin_package(scope) + .is_some_and(|package_id| user_package_ids.contains(&package_id)) + && !source_lookup.is_synthesized_callable_scope(scope); + } + + return false; + } + + false +} + +/// Determines whether a loop scope originates from library code and should be collapsed. +/// +/// Library loops (those from non-user packages or generic synthetic loop markers) are +/// collapsed to reduce clutter. User-authored loops are preserved. +fn should_collapse_non_user_loop_scope( + scope: &LexicalScope, + user_package_ids: &[PackageId], +) -> bool { + scope.name.as_ref() == "loop: " + || scope + .location + .is_some_and(|location| !user_package_ids.contains(&location.package_id)) +} + +/// Flattens loop iteration groups, extracting their children. +/// +/// When a loop scope is collapsed, its loop-iteration child groups are unwrapped, +/// promoting their operations to the parent level for a cleaner circuit structure. +fn flatten_loop_iteration_children(children: &mut Vec<OperationOrGroup>) -> Vec<OperationOrGroup> { + let mut flattened = Vec::new(); + + for mut child in children.drain(..) { + match &mut child.kind { + OperationOrGroupKind::Group { + scope_stack, + children, + } if matches!( + scope_stack.current_lexical_scope(), + Scope::LoopIteration(..) + ) => + { + flattened.extend(take(children)); + } + OperationOrGroupKind::Single | OperationOrGroupKind::Group { .. } => { + flattened.push(child); + } + } + } + + flattened +} + +/// Determines whether a synthesized callable scope should be collapsed. +/// +/// Synthesized callables from library packages are collapsed to reduce visualization noise.
+/// Synthesized callables from user packages may be retained to preserve semantic intent. +fn should_collapse_synthesized_callable_scope( + source_lookup: &impl SourceLookup, + scope: &Scope, + user_package_ids: &[PackageId], +) -> bool { + if !source_lookup.is_synthesized_callable_scope(scope) { + return false; + } + + match source_lookup.callable_scope_origin_package(scope) { + Some(package_id) => !user_package_ids.contains(&package_id), + None => true, + } +} + /// Cache for mapping loop source locations to their corresponding package and expression IDs. /// This information is repeatedly looked up when resolving loop scopes from RIR debug metadata, /// so caching it avoids expensive lookups in the FIR package store. @@ -616,6 +854,15 @@ pub trait SourceLookup { location: LogicalStackEntryLocation, loop_id_cache: &mut LoopIdCache, ) -> Option; + /// Returns whether a callable scope was synthesized during lowering rather + /// than originating from a user-declared HIR item. + /// + /// Circuit rendering uses this to collapse bookkeeping-only callable + /// scopes so they do not appear as separate groups in the final diagram. + fn is_synthesized_callable_scope(&self, scope: &Scope) -> bool; + /// Returns the package where the callable originally came from, when it + /// can be recovered from the callable scope's source metadata. 
+ fn callable_scope_origin_package(&self, scope: &Scope) -> Option<PackageId>; } impl SourceLookup for (&compile::PackageStore, &fir::PackageStore) { @@ -659,7 +906,7 @@ impl SourceLookup for (&compile::PackageStore, &fir::PackageStore) { package_id: store_item_id.package, offset: scope_offset, }), - name: callable_decl.name.name.clone(), + name: displayable_callable_scope_name(&callable_decl.name.name), is_adjoint: functor_app.adjoint, is_classically_controlled: false, } @@ -668,12 +915,12 @@ // trim the trailing dagger symbol and set `is_adjoint` accordingly let (name, is_adjoint) = if let Some(pos) = name.rfind('\'') { if pos == name.len() - 1 { - (name[..pos].to_string().into(), true) + (displayable_callable_scope_name(&name[..pos]), true) } else { - (name.clone(), false) + (displayable_callable_scope_name(name), false) } } else { - (name.clone(), false) + (displayable_callable_scope_name(name), false) }; LexicalScope { location: Some(*package_offset), @@ -692,11 +939,7 @@ .0 .get(map_fir_package_to_hir(package_id)) .and_then(|p| p.sources.find_by_offset(cond_expr.span.lo)) - .map(|s| { - s.contents[(cond_expr.span.lo - s.offset) as usize - ..(cond_expr.span.hi - s.offset) as usize] - .to_string() - }); + .and_then(|s| source_span_contents(&s.contents, s.offset, cond_expr.span)); LexicalScope { name: format!("loop: {}", expr_contents.unwrap_or_default()).into(), @@ -810,6 +1053,129 @@ } } } + + /// Treat FIR callables with no corresponding HIR item as synthesized + /// lowering artifacts, such as specialized helper scopes.
+ fn is_synthesized_callable_scope(&self, scope: &Scope) -> bool { + let Some((current_package, offset, name)) = callable_scope_origin_key(self.1, scope) else { + return false; + }; + + let Some(unit) = self.0.get(map_fir_package_to_hir(current_package)) else { + return false; + }; + + match scope { + Scope::Callable(CallableId::Id(store_item_id, _)) => { + if !unit + .package + .items + .contains_key(qsc_hir::hir::LocalItemId::from(usize::from( + store_item_id.item, + ))) + { + return true; + } + } + Scope::Callable(CallableId::Source(..)) => {} + Scope::Top + | Scope::Loop(..) + | Scope::LoopIteration(..) + | Scope::ClassicallyControlled { .. } => return false, + } + + !hir_package_contains_callable_origin(unit, offset, name.as_ref()) + } + + fn callable_scope_origin_package(&self, scope: &Scope) -> Option { + let (current_package, offset, name) = callable_scope_origin_key(self.1, scope)?; + + let current_match = self + .0 + .get(map_fir_package_to_hir(current_package)) + .and_then(|unit| { + hir_package_contains_callable_origin(unit, offset, name.as_ref()) + .then_some(current_package) + }); + + current_match.or_else(|| { + self.0.iter().find_map(|(hir_package_id, unit)| { + hir_package_contains_callable_origin(unit, offset, name.as_ref()) + .then_some(map_hir_package_to_fir(hir_package_id)) + }) + }) + } +} + +fn callable_scope_origin_key( + fir_store: &fir::PackageStore, + scope: &Scope, +) -> Option<(PackageId, u32, Rc)> { + match scope { + Scope::Callable(CallableId::Id(store_item_id, _)) => { + let item = fir_store.get_item(*store_item_id); + let fir::ItemKind::Callable(callable_decl) = &item.kind else { + return None; + }; + + Some(( + store_item_id.package, + callable_decl.span.lo, + displayable_callable_scope_name(&callable_decl.name.name), + )) + } + Scope::Callable(CallableId::Source(package_offset, name)) => Some(( + package_offset.package_id, + package_offset.offset, + source_callable_origin_name(name), + )), + Scope::Top + | Scope::Loop(..) 
+ | Scope::LoopIteration(..) + | Scope::ClassicallyControlled { .. } => None, + } +} + +fn source_callable_origin_name(name: &str) -> Rc { + if let Some(stripped) = name.strip_suffix('\'') { + displayable_callable_scope_name(stripped) + } else { + displayable_callable_scope_name(name) + } +} + +fn hir_package_contains_callable_origin( + unit: &compile::CompileUnit, + offset: u32, + name: &str, +) -> bool { + unit.package.items.values().any(|item| { + let qsc_hir::hir::ItemKind::Callable(decl) = &item.kind else { + return false; + }; + + decl.span.lo == offset && displayable_callable_scope_name(&decl.name.name).as_ref() == name + }) +} + +fn source_span_contents(contents: &str, source_offset: u32, span: Span) -> Option { + let start = usize::try_from(span.lo.checked_sub(source_offset)?).ok()?; + let end = usize::try_from(span.hi.checked_sub(source_offset)?).ok()?; + contents.get(start..end).map(ToString::to_string) +} + +fn displayable_callable_scope_name(name: &str) -> Rc { + if name.starts_with("") { + return name.into(); + } + + let suffix_start = match (name.find('<'), name.find('{')) { + (Some(functor_suffix), Some(callable_suffix)) => functor_suffix.min(callable_suffix), + (Some(functor_suffix), None) => functor_suffix, + (None, Some(callable_suffix)) => callable_suffix, + (None, None) => name.len(), + }; + name[..suffix_start].into() } fn callable_scope_offset(callable_decl: &fir::CallableDecl, functor_app: FunctorApp) -> u32 { diff --git a/source/compiler/qsc_circuit/src/builder/tests.rs b/source/compiler/qsc_circuit/src/builder/tests.rs index 54cd00e5c0..f2870cd434 100644 --- a/source/compiler/qsc_circuit/src/builder/tests.rs +++ b/source/compiler/qsc_circuit/src/builder/tests.rs @@ -7,13 +7,17 @@ mod group_scopes; mod logical_stack_trace; mod prune_classical_qubits; -use std::vec; +use std::{rc::Rc, vec}; use super::*; use expect_test::expect; +use indoc::indoc; use qsc_data_structures::{functors::FunctorApp, span::Span}; use qsc_eval::debug::Frame; -use 
qsc_fir::fir::StoreItemId; +use qsc_fir::fir::{self, ExprKind, PackageLookup, StoreItemId}; +use qsc_frontend::compile::{self, PackageStore, compile}; +use qsc_lowerer::map_hir_package_to_fir; +use qsc_passes::{PackageType, run_core_passes, run_default_passes}; use rustc_hash::FxHashMap; #[derive(Default)] @@ -68,6 +72,23 @@ impl SourceLookup for FakeCompilation { _ => panic!("only Call and Branch locations are supported in tests"), } } + + fn is_synthesized_callable_scope(&self, _scope: &Scope) -> bool { + false + } + + fn callable_scope_origin_package(&self, scope: &Scope) -> Option { + match scope { + Scope::Callable(CallableId::Id(store_item_id, _)) => Some(store_item_id.package), + Scope::Callable(CallableId::Source(package_offset, _)) => { + Some(package_offset.package_id) + } + Scope::Top + | Scope::Loop(..) + | Scope::LoopIteration(..) + | Scope::ClassicallyControlled { .. } => None, + } + } } impl FakeCompilation { @@ -149,6 +170,211 @@ impl Scopes { } } +/// Builds matching HIR and FIR package stores with core, dependency, and user +/// packages for callable-origin lookup tests. +fn compile_origin_lookup_stores() -> (PackageStore, fir::PackageStore, PackageId, PackageId) { + let mut fir_lowerer = qsc_lowerer::Lowerer::new(); + + let mut core = compile::core(); + run_core_passes(&mut core); + + let lowering_store = fir::PackageStore::new(); + let core_fir = fir_lowerer.lower_package(&core.package, &lowering_store); + let mut store = PackageStore::new(core); + + let library_source = indoc! 
{ + r#" + namespace Library { + operation LibraryHelper() : Unit { } + } + "# + }; + let mut library_unit = compile( + &store, + &[], + qsc_data_structures::source::SourceMap::new( + [("Library.qs".into(), library_source.into())], + None, + ), + qsc_data_structures::target::TargetCapabilityFlags::all(), + qsc_data_structures::language_features::LanguageFeatures::default(), + ); + assert!(library_unit.errors.is_empty(), "{:?}", library_unit.errors); + let library_pass_errors = run_default_passes(store.core(), &mut library_unit, PackageType::Lib); + assert!(library_pass_errors.is_empty(), "{library_pass_errors:?}"); + let library_fir = fir_lowerer.lower_package(&library_unit.package, &lowering_store); + let dep_unit_id = store.insert(library_unit); + let dep_pkg_id = map_hir_package_to_fir(dep_unit_id); + + let user_source = indoc! { + r#" + namespace User { + operation UserHelper() : Unit { } + } + "# + }; + let mut user_unit = compile( + &store, + &[], + qsc_data_structures::source::SourceMap::new([("User.qs".into(), user_source.into())], None), + qsc_data_structures::target::TargetCapabilityFlags::all(), + qsc_data_structures::language_features::LanguageFeatures::default(), + ); + assert!(user_unit.errors.is_empty(), "{:?}", user_unit.errors); + let user_pass_errors = run_default_passes(store.core(), &mut user_unit, PackageType::Lib); + assert!(user_pass_errors.is_empty(), "{user_pass_errors:?}"); + let user_fir = fir_lowerer.lower_package(&user_unit.package, &lowering_store); + let app_unit_id = store.insert(user_unit); + let app_pkg_id = map_hir_package_to_fir(app_unit_id); + + let mut fir_store = fir::PackageStore::new(); + fir_store.insert( + map_hir_package_to_fir(qsc_hir::hir::PackageId::CORE), + core_fir, + ); + fir_store.insert(dep_pkg_id, library_fir); + fir_store.insert(app_pkg_id, user_fir); + + (store, fir_store, dep_pkg_id, app_pkg_id) +} + +/// Copies a named FIR callable into another package while preserving its source +/// span, matching 
synthesized callables that keep their original source origin. +fn clone_callable_into_package( + fir_store: &mut fir::PackageStore, + source_package: PackageId, + target_package: PackageId, + source_name: &str, + suffix: &str, +) -> StoreItemId { + let source_item = fir_store + .get(source_package) + .items + .iter() + .find_map(|(item_id, item)| match &item.kind { + fir::ItemKind::Callable(decl) if decl.name.name.as_ref() == source_name => { + Some((item_id, item.clone())) + } + _ => None, + }) + .expect("expected callable in source package") + .1; + + let target = fir_store.get_mut(target_package); + let new_item_id = target + .items + .iter() + .map(|(item_id, _)| usize::from(item_id)) + .max() + .map_or(0, |max_id| max_id + 1) + .into(); + + let mut new_item = source_item; + new_item.id = new_item_id; + if let fir::ItemKind::Callable(decl) = &mut new_item.kind { + decl.name.name = Rc::from(format!("{}{suffix}", decl.name.name)); + } + target.items.insert(new_item_id, new_item); + + StoreItemId { + package: target_package, + item: new_item_id, + } +} + +/// Creates a source-location callable scope for the given FIR callable so tests +/// exercise origin resolution by package span instead of item id. +fn source_scope_for_callable(fir_store: &fir::PackageStore, callable_id: StoreItemId) -> Scope { + let callable = fir_store.get_item(callable_id); + let fir::ItemKind::Callable(decl) = &callable.kind else { + panic!("expected callable item"); + }; + + Scope::Callable(CallableId::Source( + PackageOffset { + package_id: callable_id.package, + offset: decl.span.lo, + }, + decl.name.name.clone(), + )) +} + +#[test] +fn synthesized_callable_scope_collapse_uses_origin_package() { + let (store, mut fir_store, library_package_id, user_package_id) = + compile_origin_lookup_stores(); + + // Move one dependency callable and one user callable into the user package + // with synthesized names while leaving their source spans intact. 
+ let library_clone = clone_callable_into_package( + &mut fir_store, + library_package_id, + user_package_id, + "LibraryHelper", + "", + ); + let user_clone = clone_callable_into_package( + &mut fir_store, + user_package_id, + user_package_id, + "UserHelper", + "{H}", + ); + + // Check both callable id scopes and source scopes because each path can be + // used when deciding whether a synthesized group should collapse. + let library_id_scope = Scope::Callable(CallableId::Id(library_clone, FunctorApp::default())); + let user_id_scope = Scope::Callable(CallableId::Id(user_clone, FunctorApp::default())); + let library_source_scope = source_scope_for_callable(&fir_store, library_clone); + let user_source_scope = source_scope_for_callable(&fir_store, user_clone); + let lookup = (&store, &fir_store); + + assert!(lookup.is_synthesized_callable_scope(&library_id_scope)); + assert!(lookup.is_synthesized_callable_scope(&user_id_scope)); + assert!(lookup.is_synthesized_callable_scope(&library_source_scope)); + assert!(!lookup.is_synthesized_callable_scope(&user_source_scope)); + + // Only synthesized callables whose origin package is outside the rendered + // user package set should collapse into their parent circuit group. 
+ assert_eq!( + lookup.callable_scope_origin_package(&library_id_scope), + Some(library_package_id) + ); + assert_eq!( + lookup.callable_scope_origin_package(&user_id_scope), + Some(user_package_id) + ); + assert_eq!( + lookup.callable_scope_origin_package(&library_source_scope), + Some(library_package_id) + ); + assert_eq!( + lookup.callable_scope_origin_package(&user_source_scope), + Some(user_package_id) + ); + + assert!(should_collapse_synthesized_callable_scope( + &lookup, + &library_id_scope, + &[user_package_id], + )); + assert!(!should_collapse_synthesized_callable_scope( + &lookup, + &user_id_scope, + &[user_package_id], + )); + assert!(should_collapse_synthesized_callable_scope( + &lookup, + &library_source_scope, + &[user_package_id], + )); + assert!(!should_collapse_synthesized_callable_scope( + &lookup, + &user_source_scope, + &[user_package_id], + )); +} + #[test] fn exceed_max_operations() { let mut builder = CircuitTracer::new( @@ -517,7 +743,6 @@ fn measurement_target_propagated_to_group() { .iter() .find(|reg| *reg == &measurement_op.qubits[0]) .expect("expected measurement qubit in group operation's targets"); - group_op .targets .iter() @@ -525,6 +750,97 @@ fn measurement_target_propagated_to_group() { .expect("expected measurement result in group operation's targets"); } +/// Verifies that loop scope resolution falls back to the loop expression when a +/// condition span cannot be mapped into the source package. +#[test] +fn resolve_scope_for_loop_tolerates_out_of_range_condition_span() { + let mut fir_lowerer = qsc_lowerer::Lowerer::new(); + let mut core = compile::core(); + run_core_passes(&mut core); + let lowering_store = fir::PackageStore::new(); + let core_fir = fir_lowerer.lower_package(&core.package, &lowering_store); + let mut store = PackageStore::new(core); + + let source = indoc! 
{ + r#" + namespace Test { + operation Main() : Unit { + mutable i = 0; + while i < 2 { + set i += 1; + } + } + } + "# + }; + let mut unit = compile( + &store, + &[], + qsc_data_structures::source::SourceMap::new( + [("A.qs".into(), source.into())], + Some("Test.Main()".into()), + ), + qsc_data_structures::target::TargetCapabilityFlags::all(), + qsc_data_structures::language_features::LanguageFeatures::default(), + ); + assert!(unit.errors.is_empty(), "{:?}", unit.errors); + let pass_errors = run_default_passes(store.core(), &mut unit, PackageType::Lib); + assert!(pass_errors.is_empty(), "{pass_errors:?}"); + let unit_fir = fir_lowerer.lower_package(&unit.package, &lowering_store); + let hir_package_id = store.insert(unit); + let fir_package_id = map_hir_package_to_fir(hir_package_id); + + let mut fir_store = fir::PackageStore::new(); + fir_store.insert( + map_hir_package_to_fir(qsc_hir::hir::PackageId::CORE), + core_fir, + ); + fir_store.insert(fir_package_id, unit_fir); + + // Capture the while expression and its condition separately so only the + // condition span is corrupted. + let (loop_expr_id, cond_expr_id) = { + let package = fir_store.get(fir_package_id); + package + .exprs + .iter() + .find_map(|(expr_id, expr)| { + if let ExprKind::While(cond_expr_id, _) = expr.kind { + Some((expr_id, cond_expr_id)) + } else { + None + } + }) + .expect("expected while loop in lowered FIR") + }; + + // Simulate transform-produced FIR whose condition span points beyond the + // source file while the enclosing loop span remains valid. + let source_len = u32::try_from(source.len()).expect("source length should fit in u32"); + let cond_expr = fir_store + .get_mut(fir_package_id) + .exprs + .get_mut(cond_expr_id) + .expect("condition expr should exist"); + cond_expr.span.hi = source_len + 100; + + // Resolution should tolerate the bad condition span and still produce a + // stable group name and source location from the loop expression itself. 
+ let scope = (&store, &fir_store).resolve_scope( + &Scope::Loop(LoopId::Id(fir_package_id, loop_expr_id)), + &mut Default::default(), + ); + + assert_eq!(scope.name.as_ref(), "loop: "); + assert_eq!( + scope.location, + Some(PackageOffset { + package_id: fir_package_id, + offset: fir_store.get(fir_package_id).get_expr(loop_expr_id).span.lo, + }) + ); +} + #[test] fn source_locations_for_groups() { let mut c = FakeCompilation::default(); diff --git a/source/compiler/qsc_circuit/src/operations.rs b/source/compiler/qsc_circuit/src/operations.rs index 23447fef4d..78ea919485 100644 --- a/source/compiler/qsc_circuit/src/operations.rs +++ b/source/compiler/qsc_circuit/src/operations.rs @@ -107,20 +107,12 @@ fn operation_circuit_entry_expr(operation_expr: &str, qubit_params: &[QubitParam let mut qs_start = 0; let mut call_args = vec![]; for q in qubit_params { - // Q# ranges are end-inclusive - let qs_end = qs_start + q.num_qubits() - 1; if q.dimensions == 0 { call_args.push(format!("qs[{qs_start}]")); } else { - // Array argument - use a range to index - let mut call_arg = format!("qs[{qs_start}..{qs_end}]"); - for _ in 1..q.dimensions { - // Chunk the array for multi-dimensional array arguments - call_arg = format!("Microsoft.Quantum.Arrays.Chunks({NUM_QUBITS}, {call_arg})"); - } - call_args.push(call_arg); + call_args.push(build_nested_qubit_array_arg(qs_start, q.dimensions)); } - qs_start = qs_end + 1; + qs_start += q.num_qubits(); } let call_args = call_args.join(", "); @@ -143,6 +135,28 @@ fn operation_circuit_entry_expr(operation_expr: &str, qubit_params: &[QubitParam /// in the operation arguments. const NUM_QUBITS: u32 = 2; +/// Constructs a nested qubit array argument for a circuit entry expression. +/// +/// Generates explicit array constructors for multi-dimensional qubit array parameters. 
+/// For example, a 2D qubit array parameter receives nested array syntax: `[[qs[0..1], qs[2..3]], [qs[4..5], qs[6..7]]]` +/// Recursively partitions the qubit range into `NUM_QUBITS` wide chunks at each dimension level. +fn build_nested_qubit_array_arg(start: u32, dimensions: u32) -> String { + debug_assert!(dimensions > 0, "array dimensions should be positive"); + + if dimensions == 1 { + let end = start + NUM_QUBITS - 1; + return format!("qs[{start}..{end}]"); + } + + let chunk_width = NUM_QUBITS.pow(dimensions - 1); + let chunks = (0..NUM_QUBITS) + .map(|chunk_index| { + build_nested_qubit_array_arg(start + chunk_index * chunk_width, dimensions - 1) + }) + .collect::>(); + format!("[{}]", chunks.join(", ")) +} + fn get_qubit_param_info(input: &Pat) -> Vec { match &input.ty { Ty::Prim(Prim::Qubit) => { diff --git a/source/compiler/qsc_circuit/src/operations/tests.rs b/source/compiler/qsc_circuit/src/operations/tests.rs index b2b788d7bd..ccb99a1667 100644 --- a/source/compiler/qsc_circuit/src/operations/tests.rs +++ b/source/compiler/qsc_circuit/src/operations/tests.rs @@ -133,7 +133,7 @@ fn qubit_params() { } #[test] -fn qubit_array_params() { +fn qubit_array_parameters_allocate_flat_register_slices() { let (item, operation) = compile_one_operation( r" namespace Test { @@ -149,7 +149,7 @@ fn qubit_array_params() { expect![[r" { use qs = Qubit[15]; - (Test.Test)(qs[0..1], Microsoft.Quantum.Arrays.Chunks(2, qs[2..5]), Microsoft.Quantum.Arrays.Chunks(2, Microsoft.Quantum.Arrays.Chunks(2, qs[6..13])), qs[14]); + (Test.Test)(qs[0..1], [qs[2..3], qs[4..5]], [[qs[6..7], qs[8..9]], [qs[10..11], qs[12..13]]], qs[14]); let r: Result[] = []; r }"]].assert_eq(&expr); diff --git a/source/compiler/qsc_circuit/src/rir_to_circuit.rs b/source/compiler/qsc_circuit/src/rir_to_circuit.rs index 148090127a..a2a810fb35 100644 --- a/source/compiler/qsc_circuit/src/rir_to_circuit.rs +++ b/source/compiler/qsc_circuit/src/rir_to_circuit.rs @@ -27,6 +27,10 @@ use crate::{ 
rir_to_circuit::control_flow::{StructuredControlFlow, reconstruct_control_flow}, }; +/// Converts a Runtime Intermediate Representation (RIR) program into a visual circuit. +/// +/// Traverses the RIR's structured control flow, collects quantum operations, tracks variable +/// assignments, and synthesizes the final circuit with scope grouping and qubit-wire mapping. pub fn rir_to_circuit( program_rir: &Program, config: TracerConfig, @@ -77,7 +81,13 @@ pub fn rir_to_circuit( // All operations from the program collected, finalize the circuit. let qubits = wire_map_builder.into_wire_map().to_qubits(source_lookup); let operations = builder.into_operations(); - let circuit = finish_circuit(source_lookup, operations, qubits, config.group_by_scope); + let circuit = finish_circuit( + source_lookup, + operations, + qubits, + config.group_by_scope, + user_package_ids, + ); Ok(circuit) } @@ -355,6 +365,7 @@ fn process_variables( | Instruction::Fsub(operand, operand1, variable) | Instruction::Fmul(operand, operand1, variable) | Instruction::Fdiv(operand, operand1, variable) + | Instruction::Frem(operand, operand1, variable) | Instruction::LogicalAnd(operand, operand1, variable) | Instruction::LogicalOr(operand, operand1, variable) | Instruction::BitwiseAnd(operand, operand1, variable) diff --git a/source/compiler/qsc_circuit/src/rir_to_circuit/tests/logical_stack_trace.rs b/source/compiler/qsc_circuit/src/rir_to_circuit/tests/logical_stack_trace.rs index b44281c00a..4da5be9dfb 100644 --- a/source/compiler/qsc_circuit/src/rir_to_circuit/tests/logical_stack_trace.rs +++ b/source/compiler/qsc_circuit/src/rir_to_circuit/tests/logical_stack_trace.rs @@ -150,11 +150,14 @@ fn check_trace(file: &str, expr: &str, expect: &Expect) { ) .into(), }; + let compute_properties = + qsc_passes::PassContext::run_fir_passes_on_fir(&fir_store, id, capabilities) + .expect("FIR passes should succeed"); let (_, rir) = fir_to_rir( &fir_store, capabilities, - None, + &compute_properties, &entry, 
PartialEvalConfig { generate_debug_metadata: true, diff --git a/source/compiler/qsc_codegen/src/qir.rs b/source/compiler/qsc_codegen/src/qir.rs index 5efb55f353..3dc23237e5 100644 --- a/source/compiler/qsc_codegen/src/qir.rs +++ b/source/compiler/qsc_codegen/src/qir.rs @@ -16,7 +16,7 @@ pub mod v2; pub fn fir_to_rir( fir_store: &qsc_fir::fir::PackageStore, capabilities: TargetCapabilityFlags, - compute_properties: Option, + compute_properties: &PackageStoreComputeProperties, entry: &ProgramEntry, partial_eval_config: PartialEvalConfig, ) -> Result<(Program, Program), qsc_partial_eval::Error> { @@ -36,7 +36,7 @@ pub fn fir_to_rir( pub fn fir_to_qir( fir_store: &qsc_fir::fir::PackageStore, capabilities: TargetCapabilityFlags, - compute_properties: Option, + compute_properties: &PackageStoreComputeProperties, entry: &ProgramEntry, ) -> Result { let mut program = get_rir_from_compilation( @@ -60,18 +60,13 @@ pub fn fir_to_qir( pub fn fir_to_qir_from_callable( fir_store: &qsc_fir::fir::PackageStore, capabilities: TargetCapabilityFlags, - compute_properties: Option, + compute_properties: &PackageStoreComputeProperties, callable: qsc_fir::fir::StoreItemId, args: Value, ) -> Result { - let compute_properties = compute_properties.unwrap_or_else(|| { - let analyzer = qsc_rca::Analyzer::init(fir_store, capabilities); - analyzer.analyze_all() - }); - let mut program = partially_evaluate_call( fir_store, - &compute_properties, + compute_properties, callable, args, capabilities, @@ -91,19 +86,14 @@ pub fn fir_to_qir_from_callable( pub fn fir_to_rir_from_callable( fir_store: &qsc_fir::fir::PackageStore, capabilities: TargetCapabilityFlags, - compute_properties: Option, + compute_properties: &PackageStoreComputeProperties, callable: qsc_fir::fir::StoreItemId, args: Value, partial_eval_config: PartialEvalConfig, ) -> Result<(Program, Program), qsc_partial_eval::Error> { - let compute_properties = compute_properties.unwrap_or_else(|| { - let analyzer = 
qsc_rca::Analyzer::init(fir_store, capabilities); - analyzer.analyze_all() - }); - let mut program = partially_evaluate_call( fir_store, - &compute_properties, + compute_properties, callable, args, capabilities, @@ -116,19 +106,14 @@ pub fn fir_to_rir_from_callable( fn get_rir_from_compilation( fir_store: &qsc_fir::fir::PackageStore, - compute_properties: Option, + compute_properties: &PackageStoreComputeProperties, entry: &ProgramEntry, capabilities: TargetCapabilityFlags, partial_eval_config: PartialEvalConfig, ) -> Result { - let compute_properties = compute_properties.unwrap_or_else(|| { - let analyzer = qsc_rca::Analyzer::init(fir_store, capabilities); - analyzer.analyze_all() - }); - partially_evaluate( fir_store, - &compute_properties, + compute_properties, entry, capabilities, partial_eval_config, diff --git a/source/compiler/qsc_codegen/src/qir/v1.rs b/source/compiler/qsc_codegen/src/qir/v1.rs index f2cc3040a0..a497e2c108 100644 --- a/source/compiler/qsc_codegen/src/qir/v1.rs +++ b/source/compiler/qsc_codegen/src/qir/v1.rs @@ -178,6 +178,9 @@ impl ToQir for rir::Instruction { rir::Instruction::Fdiv(lhs, rhs, variable) => { fbinop_to_qir("fdiv", lhs, rhs, *variable, program) } + rir::Instruction::Frem(lhs, rhs, variable) => { + fbinop_to_qir("frem", lhs, rhs, *variable, program) + } rir::Instruction::Fmul(lhs, rhs, variable) => { fbinop_to_qir("fmul", lhs, rhs, *variable, program) } diff --git a/source/compiler/qsc_codegen/src/qir/v2.rs b/source/compiler/qsc_codegen/src/qir/v2.rs index ba41f88dbf..ec07e24702 100644 --- a/source/compiler/qsc_codegen/src/qir/v2.rs +++ b/source/compiler/qsc_codegen/src/qir/v2.rs @@ -171,6 +171,9 @@ impl ToQir for rir::Instruction { rir::Instruction::Fdiv(lhs, rhs, variable) => { fbinop_to_qir("fdiv", lhs, rhs, *variable, program) } + rir::Instruction::Frem(lhs, rhs, variable) => { + fbinop_to_qir("frem", lhs, rhs, *variable, program) + } rir::Instruction::Fmul(lhs, rhs, variable) => { fbinop_to_qir("fmul", lhs, rhs, *variable, 
program) } diff --git a/source/compiler/qsc_data_structures/src/functors.rs b/source/compiler/qsc_data_structures/src/functors.rs index d712b67522..022192fc96 100644 --- a/source/compiler/qsc_data_structures/src/functors.rs +++ b/source/compiler/qsc_data_structures/src/functors.rs @@ -8,7 +8,7 @@ use std::{ }; /// A functor application. -#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] pub struct FunctorApp { /// An invocation is either adjoint or not, with each successive use of `Adjoint` functor switching /// between the two, so a bool is sufficient to track. diff --git a/source/compiler/qsc_eval/src/lib.rs b/source/compiler/qsc_eval/src/lib.rs index 093bf625b4..a3cfe6d11a 100644 --- a/source/compiler/qsc_eval/src/lib.rs +++ b/source/compiler/qsc_eval/src/lib.rs @@ -563,7 +563,7 @@ impl Env { } } -#[derive(Default)] +#[derive(Clone, Default)] struct Scope { bindings: IndexMap, frame_id: usize, @@ -1120,7 +1120,9 @@ impl State { Some(var) => { var.value.append_array(rhs); } - None => return Err(Error::UnboundName(self.to_global_span(lhs.span))), + None => { + return Err(Error::UnboundName(self.to_global_span(lhs.span))); + } }, _ => unreachable!("unassignable array update pattern should be disallowed by compiler"), } @@ -1200,6 +1202,7 @@ impl State { Ok(()) } + #[allow(clippy::too_many_lines)] fn eval_call( &mut self, env: &mut Env, @@ -1228,7 +1231,9 @@ impl State { self.set_val_register(arg); return Ok(()); } - None => return Err(Error::UnboundName(self.to_global_span(callable_span))), + None => { + return Err(Error::UnboundName(self.to_global_span(callable_span))); + } }; let callee_span = self.to_global_span(callee.span); @@ -1677,7 +1682,9 @@ impl State { Some(var) => { var.value = rhs; } - None => return Err(Error::UnboundName(self.to_global_span(lhs.span))), + None => { + return Err(Error::UnboundName(self.to_global_span(lhs.span))); + } }, (ExprKind::Tuple(var_tup), Value::Tuple(tup, _)) => 
{ for (expr, val) in var_tup.iter().zip(tup.iter()) { diff --git a/source/compiler/qsc_eval/src/tests.rs b/source/compiler/qsc_eval/src/tests.rs index 0634e8ca46..f02bd00e60 100644 --- a/source/compiler/qsc_eval/src/tests.rs +++ b/source/compiler/qsc_eval/src/tests.rs @@ -212,6 +212,31 @@ fn block_empty_is_unit_expr() { check_expr("", "{}", &expect!["()"]); } +#[test] +fn qubit_array_length_expr() { + check_expr( + "", + indoc! {"{ + use qs = Qubit[4]; + Length(qs) + }"}, + &expect!["4"], + ); +} + +#[test] +fn qubit_array_chunks_expr() { + check_expr( + "", + indoc! {"{ + use qs = Qubit[4]; + let chunks = Std.Arrays.Chunks(2, qs); + Length(chunks[0]) + }"}, + &expect!["2"], + ); +} + #[test] fn block_shadowing_expr() { check_expr( diff --git a/source/compiler/qsc_fir/src/assigner.rs b/source/compiler/qsc_fir/src/assigner.rs index 1676a82ef6..04f5ed9dc1 100644 --- a/source/compiler/qsc_fir/src/assigner.rs +++ b/source/compiler/qsc_fir/src/assigner.rs @@ -1,7 +1,33 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -use crate::fir::{BlockId, ExprId, LocalVarId, NodeId, PatId, StmtId}; +//! FIR node-ID allocator. +//! +//! [`Assigner`] provides monotonically increasing IDs for every FIR arena type +//! (`BlockId`, `StmtId`, `ExprId`, `PatId`, `LocalItemId`, `LocalVarId`, +//! `NodeId`). IDs are **never reused or decremented**. +//! +//! # Append-only arena contract +//! +//! FIR arenas (`Package.blocks`, `.stmts`, `.exprs`, `.pats`) are backed by +//! `IndexMap` which stores `Vec>`. FIR transform passes +//! create new nodes via `Assigner::next_*()` and may mutate existing nodes +//! in-place, but they **never remove entries** from the arenas. This means +//! pre-transform nodes remain as populated-but-unreachable entries ("orphans") +//! after transforms complete. +//! +//! Any code that iterates a FIR arena directly (via `IndexMap::iter()`) will +//! encounter orphan entries alongside live entries. Analyzers must either: +//! 
- Filter to reachable nodes before processing (see `qsc_rca::common`), or +//! - Tolerate orphan entries gracefully (e.g., in-place type mutations). +//! +//! The `gc_unreachable` pass in `qsc_fir_transforms` can tombstone orphan +//! entries after the pipeline completes, making `iter()` skip them. + +use crate::fir::{ + BlockId, CallableImpl, ExprId, ExprKind, LocalItemId, LocalVarId, NodeId, Package, PatId, + PatKind, Res, StmtId, +}; #[derive(Debug)] pub struct Assigner { @@ -12,6 +38,7 @@ pub struct Assigner { next_stmt: StmtId, next_local: LocalVarId, stashed_local: LocalVarId, + next_item: LocalItemId, } impl Assigner { @@ -25,6 +52,7 @@ impl Assigner { next_stmt: StmtId::default(), next_local: LocalVarId::default(), stashed_local: LocalVarId::default(), + next_item: LocalItemId::default(), } } @@ -64,6 +92,40 @@ impl Assigner { id } + pub fn next_item(&mut self) -> LocalItemId { + let id = self.next_item; + self.next_item = id.successor(); + id + } + + pub fn set_next_node(&mut self, id: NodeId) { + self.next_node = id; + } + + pub fn set_next_block(&mut self, id: BlockId) { + self.next_block = id; + } + + pub fn set_next_expr(&mut self, id: ExprId) { + self.next_expr = id; + } + + pub fn set_next_pat(&mut self, id: PatId) { + self.next_pat = id; + } + + pub fn set_next_stmt(&mut self, id: StmtId) { + self.next_stmt = id; + } + + pub fn set_next_local(&mut self, id: LocalVarId) { + self.next_local = id; + } + + pub fn set_next_item(&mut self, id: LocalItemId) { + self.next_item = id; + } + pub fn stash_local(&mut self) { self.stashed_local = self.next_local; self.next_local = LocalVarId::default(); @@ -73,6 +135,104 @@ impl Assigner { self.next_local = self.stashed_local; self.stashed_local = LocalVarId::default(); } + + /// Creates an `Assigner` whose counters are advanced past the maximum + /// existing IDs in `package`. 
+ #[must_use] + pub fn from_package(package: &Package) -> Self { + let mut assigner = Self::new(); + + // BlockId + let max_block = package.blocks.iter().map(|(id, _)| u32::from(id)).max(); + if let Some(max) = max_block { + assigner.set_next_block(BlockId::from(max + 1)); + } + + // ExprId + let max_expr = package.exprs.iter().map(|(id, _)| u32::from(id)).max(); + if let Some(max) = max_expr { + assigner.set_next_expr(ExprId::from(max + 1)); + } + + // PatId + let max_pat = package.pats.iter().map(|(id, _)| u32::from(id)).max(); + if let Some(max) = max_pat { + assigner.set_next_pat(PatId::from(max + 1)); + } + + // StmtId + let max_stmt = package.stmts.iter().map(|(id, _)| u32::from(id)).max(); + if let Some(max) = max_stmt { + assigner.set_next_stmt(StmtId::from(max + 1)); + } + + // NodeId — scan callable and spec decls + let mut max_node: u32 = 0; + for item in package.items.values() { + if let crate::fir::ItemKind::Callable(decl) = &item.kind { + let decl_node: u32 = decl.id.into(); + max_node = max_node.max(decl_node); + Self::max_node_from_impl(&decl.implementation, &mut max_node); + } + } + assigner.set_next_node(NodeId::from(max_node + 1)); + + // LocalVarId — scan PatKind::Bind, ExprKind::Var(Res::Local), + // ExprKind::Closure + let mut max_local: u32 = 0; + for (_, pat) in &package.pats { + if let PatKind::Bind(ident) = &pat.kind { + let v: u32 = ident.id.into(); + max_local = max_local.max(v); + } + } + for (_, expr) in &package.exprs { + if let ExprKind::Var(Res::Local(var), _) = &expr.kind { + let v: u32 = (*var).into(); + max_local = max_local.max(v); + } + if let ExprKind::Closure(vars, _) = &expr.kind { + for var in vars { + let v: u32 = (*var).into(); + max_local = max_local.max(v); + } + } + } + assigner.set_next_local(LocalVarId::from(max_local + 1)); + + // LocalItemId — scan package.items keys + let max_item = package + .items + .iter() + .map(|(k, _)| -> usize { k.into() }) + .max(); + if let Some(max) = max_item { + 
assigner.set_next_item(LocalItemId::from(max + 1)); + } + + assigner + } + + fn max_node_from_impl(callable_impl: &CallableImpl, max_node: &mut u32) { + match callable_impl { + CallableImpl::Intrinsic => {} + CallableImpl::Spec(spec_impl) => { + let body_node: u32 = spec_impl.body.id.into(); + *max_node = (*max_node).max(body_node); + for spec in [&spec_impl.adj, &spec_impl.ctl, &spec_impl.ctl_adj] + .into_iter() + .flatten() + { + let n: u32 = spec.id.into(); + *max_node = (*max_node).max(n); + } + } + CallableImpl::SimulatableIntrinsic(spec) => { + let n: u32 = spec.id.into(); + *max_node = (*max_node).max(n); + } + } + } } impl Default for Assigner { diff --git a/source/compiler/qsc_fir/src/fir.rs b/source/compiler/qsc_fir/src/fir.rs index b73482c987..dd23b7dc70 100644 --- a/source/compiler/qsc_fir/src/fir.rs +++ b/source/compiler/qsc_fir/src/fir.rs @@ -563,7 +563,7 @@ pub trait PackageLookup { /// within the containing node. Node ids are used to identify nodes within /// the package and require mapping from the HIR node id to the new FIR node id. /// `PackageId`s and `LocalItemId`s are 1:1 from the HIR and are not remapped. -#[derive(Debug)] +#[derive(Debug, Default)] pub struct Package { /// The items in the package. pub items: IndexMap, @@ -937,7 +937,7 @@ impl ExecGraph { #[must_use] /// Selects the execution graph based on the configuration. - fn select_ref(&self, exec_graph_config: ExecGraphConfig) -> &ConfiguredExecGraph { + pub fn select_ref(&self, exec_graph_config: ExecGraphConfig) -> &ConfiguredExecGraph { match exec_graph_config { ExecGraphConfig::Debug => &self.debug, ExecGraphConfig::NoDebug => &self.no_debug, @@ -992,6 +992,13 @@ pub struct ExecGraphIdx { } impl ExecGraphIdx { + /// A zero-valued index, used as a placeholder for synthesized FIR nodes + /// that do not participate in the execution graph. + pub const ZERO: Self = Self { + no_debug_idx: 0, + debug_idx: 0, + }; + /// Selects the index based on the configuration. 
fn select(self, exec_graph_config: ExecGraphConfig) -> usize { match exec_graph_config { diff --git a/source/compiler/qsc_fir/src/ty.rs b/source/compiler/qsc_fir/src/ty.rs index d88ee98bd5..5cf148814f 100644 --- a/source/compiler/qsc_fir/src/ty.rs +++ b/source/compiler/qsc_fir/src/ty.rs @@ -465,6 +465,19 @@ impl FunctorSetValue { } } +impl FunctorSetValue { + /// Returns a compact identifier suitable for name mangling. + #[must_use] + pub fn mangle_name(&self) -> &'static str { + match self { + Self::Empty => "Empty", + Self::Adj => "Adj", + Self::Ctl => "Ctl", + Self::CtlAdj => "AdjCtl", + } + } +} + impl Display for FunctorSetValue { fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self { diff --git a/source/compiler/qsc_fir_transforms/Cargo.toml b/source/compiler/qsc_fir_transforms/Cargo.toml new file mode 100644 index 0000000000..07db8826ce --- /dev/null +++ b/source/compiler/qsc_fir_transforms/Cargo.toml @@ -0,0 +1,46 @@ +[package] +name = "qsc_fir_transforms" + +version.workspace = true +authors.workspace = true +homepage.workspace = true +repository.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +miette = { workspace = true } +thiserror = { workspace = true } +num-bigint = { workspace = true } +qsc_data_structures = { path = "../qsc_data_structures" } +qsc_fir = { path = "../qsc_fir" } +qsc_formatter = { path = "../qsc_formatter" } +qsc_frontend = { path = "../qsc_frontend", optional = true } +qsc_hir = { path = "../qsc_hir", optional = true } +qsc_lowerer = { path = "../qsc_lowerer" } +qsc_passes = { path = "../qsc_passes", optional = true } +rustc-hash = { workspace = true } + +[features] +slow-proptest-tests = [] +testutil = ["qsc_frontend", "qsc_hir", "qsc_passes"] + +[dev-dependencies] +qsc_fir_transforms = { path = ".", features = ["testutil"] } +expect-test = { workspace = true } +indoc = { workspace = true } +proptest = { workspace = true } +qsc_codegen = { path = "../qsc_codegen" } +qsc_eval = { path = 
"../qsc_eval" } +qsc_frontend = { path = "../qsc_frontend" } +qsc_hir = { path = "../qsc_hir" } +qsc_partial_eval = { path = "../qsc_partial_eval" } +qsc_parse = { path = "../qsc_parse" } +qsc_passes = { path = "../qsc_passes" } +qsc_rca = { path = "../qsc_rca" } + +[lints] +workspace = true + +[lib] +doctest = false diff --git a/source/compiler/qsc_fir_transforms/README.md b/source/compiler/qsc_fir_transforms/README.md new file mode 100644 index 0000000000..657785061e --- /dev/null +++ b/source/compiler/qsc_fir_transforms/README.md @@ -0,0 +1,141 @@ +# Overview + +`qsc_fir_transforms` owns the production FIR-to-FIR rewrite schedule that runs +after FIR lowering and before downstream consumers such as partial evaluation +and backend code generation. + +The passes in this crate are ordered and staged as one pipeline. They are not +intended to be individually sound in arbitrary combinations. Some intermediate +results are only valid because later passes restore the structural guarantees +that downstream code expects. + +Most rewrites are entry-reachability-driven. They inspect the code that can be +reached from the package entry expression and limit mutation accordingly. The +main exception is UDT erasure, which is still reachability-scoped but operates +at package granularity within the reachable package closure: it rewrites the +target package plus any package that contains an entry-reachable callable, +leaves unreachable packages untouched, and resolves UDT definitions from the +whole store. + +## Public entry point + +`run_pipeline` is the public production entry point. It runs the full rewrite +schedule on one FIR package and returns pipeline diagnostics produced by +`return_unify` or `defunctionalize`. + +Inside the crate, `run_pipeline_to` provides stage cut points for tests. The +`Sroa` and `ExecGraphRebuild` cut points are test-only conveniences. Production +code uses the full schedule. + +## Pipeline + +The authoritative pass order is: + +1. 
`monomorphize` +2. `return_unify` +3. `defunctionalize` +4. `udt_erase` +5. `tuple_compare_lower` +6. `sroa` +7. `arg_promote` +8. `gc_unreachable` +9. `item_dce` +10. `exec_graph_rebuild` + +The passes have the following responsibilities: + +1. `monomorphize` specializes reachable generic callables to the concrete types + used from the entry expression. +2. `return_unify` rewrites callable bodies to a single-exit form by + eliminating all `ExprKind::Return` nodes while preserving path-local side + effects such as qubit release calls. +3. `defunctionalize` eliminates callable-valued expressions and rewrites call + sites to direct callable references where possible. +4. `udt_erase` replaces UDT-typed values and struct expressions in the + reachable package closure with their pure tuple or scalar representation. +5. `tuple_compare_lower` lowers equality and inequality on non-empty tuples to + element-wise scalar comparisons. +6. `sroa` iteratively decomposes tuple-valued locals when every use is a field + access or another decomposable aggregate update. +7. `arg_promote` iteratively decomposes tuple-valued callable parameters and + updates reachable call sites. +8. `gc_unreachable` tombstones orphaned blocks, stmts, exprs, and pats that + are no longer reachable from any callable body or entry expression. +9. `item_dce` removes unreachable callable and type items left behind by + monomorphization and defunctionalization. +10. `exec_graph_rebuild` recomputes exec-graph metadata after earlier passes + synthesize new FIR nodes. + +Invariant checks run after `monomorphize`, `return_unify`, `defunctionalize`, +`udt_erase`, `tuple_compare_lower`, `sroa`, `arg_promote`, and +`gc_unreachable`, and then once more after `exec_graph_rebuild` when the full +pipeline completes. The `item_dce` pass does not have a dedicated invariant +check; the final `PostAll` check covers its effects. 
+ +## Module guide + +* `src/lib.rs` defines the production schedule, the stage cut points used by + crate tests, and the shared pipeline contract. +* `src/monomorphize.rs`, `src/return_unify.rs`, `src/defunctionalize.rs`, + `src/udt_erase.rs`, `src/tuple_compare_lower.rs`, `src/sroa.rs`, + `src/arg_promote.rs`, `src/gc_unreachable.rs`, `src/item_dce.rs`, and + `src/exec_graph_rebuild.rs` implement the ordered transform stages. +* `src/invariants.rs` defines the staged structural checks that validate + intermediate and final pipeline states. +* `src/reachability.rs` computes the entry-reachable callable set shared by + multiple passes. +* `src/walk_utils.rs` provides traversal, use-collection, and ID-allocation + helpers for passes that rewrite FIR in place. +* `src/cloner.rs` provides reusable deep-cloning support for passes that need + to synthesize FIR while preserving consistent ID remapping. +* `src/pretty.rs` provides a FIR-to-Q# pretty-printer used by before/after + snapshot tests for pass debugging. +* `src/test_utils.rs` provides crate-local helpers that compile Q# snippets, + lower them to FIR, and run the authoritative schedule to an intermediate + stage. 
+ +## Transformation shapes + +| Pass | Before | After | +|------|--------|-------| +| `monomorphize` | Generic callables with `Ty::Param` and non-empty generic-argument lists | Concrete callables; all `Ty::Param` resolved, generic-argument lists empty | +| `return_unify` | Multiple `ExprKind::Return` nodes in callable bodies, including raw qubit-release return wrappers | Single-exit form; no `Return` nodes remain in reachable code, and path-local releases stay on their original paths | +| `defunctionalize` | Arrow-typed parameters, closures, indirect callable dispatch | Direct dispatch only; no `ExprKind::Closure` or arrow-typed params in reachable code | +| `udt_erase` | `Ty::Udt` values, `ExprKind::Struct`, `Field::Path` in update/assign | Pure tuple or scalar representations; no UDT surface in reachable package closure | +| `tuple_compare_lower` | `BinOp(Eq/Neq)` on non-empty tuple-typed operands | Element-wise scalar `AndL`/`OrL` chains with `Field` extractions | +| `sroa` | Tuple-valued locals used only via field access | Decomposed scalar bindings; tuple binding replaced by per-field `PatKind::Bind` | +| `arg_promote` | Tuple-valued callable parameters | Flattened scalar parameters; call sites pass individual fields | +| `gc_unreachable` | Orphaned arena nodes (blocks, stmts, exprs, pats) from earlier rewrites | Tombstoned entries; only nodes reachable from package items or the entry expression survive | +| `item_dce` | Unreachable callable/type items (original generics, dead closure items) | Items removed from `Package::items`; `gc_unreachable` re-runs if items were deleted | +| `exec_graph_rebuild` | Stale `exec_graph_range` with `EMPTY_EXEC_RANGE` sentinels | Fresh exec graphs rebuilt from the rewritten FIR tree | + +## Testing + +The crate uses both pass-local unit tests and end-to-end integration tests. + +* Unit tests live next to each pass and focus on localized rewrites, + invariants, and edge cases. 
+* `src/invariants/tests.rs` adds mutation-style coverage for staged structural + guarantees. +* `tests/pipeline_integration.rs` compiles Q# snippets through the full + pipeline, compares the public `run_pipeline` wrapper with an explicit pass + schedule, and preserves targeted regression cases. +* The integration tests can call the public `run_pipeline_to` stage cut points, + but still duplicate the pass order intentionally when they need explicit + parity checks against the production schedule. + +## Test lanes + +The default test lane keeps deterministic tests enabled and excludes the slower +semantic-equivalence proptests: + +```bash +cargo test -p qsc_fir_transforms +``` + +Enable the slower proptest-backed semantic-equivalence suites with the +`slow-proptest-tests` feature: + +```bash +cargo test -p qsc_fir_transforms --features slow-proptest-tests +``` diff --git a/source/compiler/qsc_fir_transforms/src/arg_promote.rs b/source/compiler/qsc_fir_transforms/src/arg_promote.rs new file mode 100644 index 0000000000..8e86afc993 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/arg_promote.rs @@ -0,0 +1,1351 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Argument promotion pass. +//! +//! Decomposes tuple-typed parameters of callable declarations into individual +//! scalar parameters, eliminating intermediate tuple allocations at call sites +//! and field-access overhead in callable bodies. +//! +//! Establishes [`crate::invariants::InvariantLevel::PostArgPromote`]: +//! synthesized callable input tuple patterns agree with their +//! input types. +//! +//! For each entry-reachable callable, the pass: +//! - Identifies parameters bound via `PatKind::Bind(p)` with `Ty::Tuple(...)` +//! or `Ty::Udt(Res::Item(_))` where every use in every specialization body +//! is a field access. +//! - Verifies the callable is not used as a first-class value, referenced +//! 
as a closure target, or otherwise left indirectly dispatched. +//! First-class detection and closure-target detection together cover +//! the partial-application cases that used to be enumerated separately. +//! - Decomposes the binding in `CallableDecl.input` and rewrites field +//! accesses in all specialization bodies. +//! - Rewrites all call sites to pass individual fields instead of the whole +//! tuple/struct argument. +//! +//! Callables that appear as first-class values (a `Var(Res::Item(_))` with +//! `Ty::Arrow` outside the direct callee position of a `Call`) or as closure +//! targets in reachable code are disqualified, because their parameter layout +//! must remain stable for indirect invocation. +//! +//! The pass iterates to a fixed point, peeling one level of tuple nesting +//! per iteration (identical to SROA's iterative strategy). +//! +//! # Pipeline position +//! +//! This pass runs after SROA and before unreachable-node GC. At this point, +//! tuple-heavy parameter shapes are already simplified by earlier passes, and +//! argument promotion can rewrite callable signatures and direct call sites +//! without fighting major later structural rewrites. +//! +//! # Architecture +//! +//! One fixed-point iteration performs: +//! +//! 1. **Reachability scan** ([`collect_reachable_from_entry`]): +//! Limit work to entry-reachable callables. +//! 2. **Eligibility analysis** ([`check_candidates`]): +//! Find `PatKind::Bind` inputs of tuple/UDT shape whose uses are field-only +//! across every specialization. +//! 3. **Safety filters** ([`collect_first_class_callables`], +//! [`collect_closure_targets`]): +//! Exclude callables used as first-class values or closure targets. +//! 4. **Signature/body rewrite** ([`promote_candidate`]): +//! Replace promoted bind patterns with tuple patterns over fresh scalar +//! locals and rewrite body uses to read those locals. +//! 5. **Call-site rewrite** ([`rewrite_call_sites`]): +//! 
Rewrite direct call arguments to match promoted input shapes. +//! +//! After the fixed point converges, [`normalize_call_arg_types`] performs a +//! package-wide call-shape normalization pass to ensure argument expression +//! types exactly match callable input types (for example, +//! `T` to `(T)` wrapping for single-element tuple inputs). +//! +//! # Input patterns +//! +//! - `operation Foo(p : (Int, Qubit)) { use(p::0); apply(p::1); }` — a +//! tuple-typed parameter whose every use is a field projection. +//! +//! # Rewrites +//! +//! ```text +//! // Before +//! operation Foo(p : (Int, Qubit)) { use(p::0); apply(p::1); } +//! Foo((42, q)); +//! +//! // After +//! operation Foo(p_0 : Int, p_1 : Qubit) { use(p_0); apply(p_1); } +//! Foo(42, q); +//! ``` +//! +//! Nested fixed-point example: +//! +//! ```text +//! // Before +//! operation Foo(p : ((Int, Bool), Qubit)) : Unit { +//! let x = p::0::0; +//! let b = p::0::1; +//! let q = p::1; +//! if b { X(q); } +//! } +//! +//! // After first promotion pass +//! operation Foo(p_0 : (Int, Bool), p_1 : Qubit) : Unit { +//! let x = p_0::0; +//! let b = p_0::1; +//! let q = p_1; +//! if b { X(q); } +//! } +//! +//! // After fixed-point convergence +//! operation Foo(p_0_0 : Int, p_0_1 : Bool, p_1 : Qubit) : Unit { +//! let x = p_0_0; +//! let b = p_0_1; +//! let q = p_1; +//! if b { X(q); } +//! } +//! ``` +//! +//! Single-element tuple call-shape normalization example: +//! +//! ```text +//! // Callable input after prior rewrites/promotions +//! operation UseOne(p : (Qubit[])) : Unit { ... } +//! +//! // Call expression before normalization (type mismatch) +//! UseOne(qs); // arg type: Qubit[] +//! +//! // Call expression after normalization +//! UseOne((qs,)); // arg type: (Qubit[]) +//! ``` +//! +//! # Notes +//! +//! - Synthesized expressions use `EMPTY_EXEC_RANGE`; +//! [`crate::exec_graph_rebuild`] rebuilds correct exec graphs at the end +//! of the pipeline. +//! +//! # References +//! +//! 
This pass is named after LLVM's `ArgumentPromotion` pass (also +//! `argpromotion`), which promotes pointer arguments to pass-by-value. +//! This Q# variant operates on tuple aggregates rather than pointers. +//! +//! + +#[cfg(test)] +mod tests; + +#[cfg(all(test, feature = "slow-proptest-tests"))] +mod semantic_equivalence_tests; + +use crate::EMPTY_EXEC_RANGE; +use crate::fir_builder::{ + decompose_binding, functored_specs, reachable_local_callables, resolve_udt_element_types, +}; +use crate::reachability::collect_reachable_from_entry; +use crate::walk_utils::{ + collect_expr_ids_in_entry_and_local_callables, collect_expr_ids_in_local_callables, + collect_uses_in_block, for_each_expr, for_each_expr_in_callable_impl, +}; +use qsc_data_structures::span::Span; +use qsc_fir::assigner::Assigner; +use qsc_fir::fir::{ + Block, BlockId, CallableDecl, CallableImpl, Expr, ExprId, ExprKind, Field, FieldPath, Ident, + ItemKind, LocalItemId, LocalVarId, Mutability, Package, PackageId, PackageLookup, PackageStore, + Pat, PatId, PatKind, Res, SpecDecl, SpecImpl, Stmt, StmtId, StmtKind, StoreItemId, +}; +use qsc_fir::ty::Ty; +use rustc_hash::FxHashSet; +use std::rc::Rc; + +/// Runs argument promotion on the entry-reachable portion of a package. +/// +/// # Before +/// ```text +/// operation Foo(p : (Int, Qubit)) : Unit { use(p::0); apply(p::1); } +/// Foo((42, q)); +/// ``` +/// +/// # After +/// ```text +/// operation Foo(p_0 : Int, p_1 : Qubit) : Unit { use(p_0); apply(p_1); } +/// Foo(42, q); +/// ``` +/// +/// # Requires +/// - `package_id` exists in `store`. +/// - `assigner` is the pipeline-global assigner (ID continuity across passes). +/// +/// # Ensures +/// - Rewrites only entry-reachable callables. +/// - Leaves first-class and closure-target callables unchanged. +/// - Normalizes call argument shapes to match callable input types via +/// [`normalize_call_arg_types`]. +/// +/// # Mutations +/// - Rewrites callable input patterns and specialization bodies. 
+/// - Rewrites direct call expressions targeting promoted callables. +/// - Allocates fresh FIR nodes via `assigner` with `EMPTY_EXEC_RANGE`. +pub fn arg_promote(store: &mut PackageStore, package_id: PackageId, assigner: &mut Assigner) { + let package = store.get(package_id); + if package.entry.is_none() { + return; + } + + promote_to_fixed_point(store, package_id, assigner); + normalize_reachable_call_arg_types(store, package_id, assigner); +} + +/// Iterates promotion rounds until no more candidates are found. +/// +/// Each iteration peels one level of tuple nesting from eligible parameters, +/// rewrites their bodies and call sites, then recomputes reachability for +/// the next round. +fn promote_to_fixed_point( + store: &mut PackageStore, + package_id: PackageId, + assigner: &mut Assigner, +) { + loop { + let candidates = find_promotion_candidates(store, package_id); + if candidates.is_empty() { + break; + } + apply_promotions(store, package_id, assigner, &candidates); + } +} + +/// Finds all eligible promotion candidates in the current reachable set, +/// excluding callables used as first-class values or closure targets. +fn find_promotion_candidates( + store: &PackageStore, + package_id: PackageId, +) -> Vec { + let reachable = collect_reachable_from_entry(store, package_id); + let package = store.get(package_id); + + let first_class = collect_first_class_callables(package, package_id, &reachable); + let closure_targets = collect_closure_targets(package, package_id, &reachable); + + let mut candidates: Vec = Vec::new(); + for (item_id, decl) in reachable_local_callables(package, package_id, &reachable) { + if first_class.contains(&item_id) || closure_targets.contains(&item_id) { + continue; + } + candidates.extend(check_candidates(store, package, package_id, item_id, decl)); + } + candidates +} + +/// Applies a batch of promotion candidates: decomposes parameters, rewrites +/// bodies, and rewrites call sites scoped to reachable expressions. 
+fn apply_promotions( + store: &mut PackageStore, + package_id: PackageId, + assigner: &mut Assigner, + candidates: &[ArgPromoCandidate], +) { + let reachable = collect_reachable_from_entry(store, package_id); + let package = store.get(package_id); + let local_item_ids: Vec<_> = reachable_local_callables(package, package_id, &reachable) + .map(|(id, _)| id) + .collect(); + let reachable_expr_ids = + collect_expr_ids_in_entry_and_local_callables(package, &local_item_ids); + + let package = store.get_mut(package_id); + let mut promotions: Vec = Vec::new(); + for candidate in candidates { + if let Some(result) = promote_candidate(package, assigner, candidate) { + promotions.push(result); + } + } + + if !promotions.is_empty() { + rewrite_call_sites( + package, + package_id, + assigner, + &promotions, + &reachable_expr_ids, + ); + } +} + +/// Normalizes call-argument types across all reachable call sites after +/// promotion has converged. +fn normalize_reachable_call_arg_types( + store: &mut PackageStore, + package_id: PackageId, + assigner: &mut Assigner, +) { + let reachable = collect_reachable_from_entry(store, package_id); + let package = store.get(package_id); + let local_item_ids: Vec<_> = reachable_local_callables(package, package_id, &reachable) + .map(|(id, _)| id) + .collect(); + let reachable_expr_ids = + collect_expr_ids_in_entry_and_local_callables(package, &local_item_ids); + let package = store.get_mut(package_id); + normalize_call_arg_types(package, package_id, assigner, &reachable_expr_ids); +} + +/// A candidate for argument promotion. +struct ArgPromoCandidate { + /// The `LocalItemId` of the callable. + item_id: LocalItemId, + /// The `LocalVarId` bound by the parameter. + local_id: LocalVarId, + /// The `PatId` of the input binding pattern. + pat_id: PatId, + /// Element types from the tuple. + elem_types: Vec, + /// The name of the original parameter. + name: Rc, + /// Whether this is a top-level promotion (`pat_id == decl.input`). 
+ /// Top-level promotions require call-site rewriting; sub-parameter + /// promotions (inside a `PatKind::Tuple`) do not. + is_top_level: bool, +} + +/// Result of promoting a candidate — tracks the callable and its element types +/// so that call sites can be rewritten. +struct PromotionResult { + /// The callable's `LocalItemId`. + item_id: LocalItemId, + /// Element types. + elem_types: Vec, +} + +/// Checks whether a callable's input parameter is a single tuple-typed or +/// UDT-typed binding whose only uses in all specialization bodies are field +/// accesses. Also recurses into `PatKind::Tuple` sub-patterns to find +/// inner bindings eligible for promotion after a previous pass. +fn check_candidates( + store: &PackageStore, + package: &Package, + _package_id: PackageId, + item_id: LocalItemId, + decl: &CallableDecl, +) -> Vec { + let mut candidates = Vec::new(); + find_param_binds_in_pat( + store, + package, + item_id, + decl, + decl.input, + true, + &mut candidates, + ); + candidates +} + +/// Recursively walks a callable's input pattern to find `PatKind::Bind` nodes +/// with tuple or UDT types whose uses are all field accesses. 
+fn find_param_binds_in_pat( + store: &PackageStore, + package: &Package, + item_id: LocalItemId, + decl: &CallableDecl, + pat_id: PatId, + is_top_level: bool, + candidates: &mut Vec, +) { + let pat = package.get_pat(pat_id); + match &pat.kind { + PatKind::Bind(ident) => { + let elem_types = match &pat.ty { + Ty::Tuple(elems) if !elems.is_empty() => Some(elems.clone()), + Ty::Udt(Res::Item(udt_item_id)) => resolve_udt_element_types(store, udt_item_id), + _ => None, + }; + if let Some(elem_types) = elem_types { + let local_id = ident.id; + if all_param_uses_are_field_access(package, decl, local_id) { + candidates.push(ArgPromoCandidate { + item_id, + local_id, + pat_id, + elem_types, + name: ident.name.clone(), + is_top_level, + }); + } + } + } + PatKind::Tuple(sub_pats) => { + for &sub_pat_id in sub_pats { + find_param_binds_in_pat( + store, package, item_id, decl, sub_pat_id, false, candidates, + ); + } + } + PatKind::Discard => {} + } +} + +/// Returns `true` if every use of `local_id` across all specialization bodies +/// of the callable is a field access. +/// +/// Intrinsic callables short-circuit to `true`: they have no user body to +/// analyze for field-projection eligibility, so the callable parameter +/// layout for an intrinsic is considered trivially field-only. +fn all_param_uses_are_field_access( + package: &Package, + decl: &CallableDecl, + local_id: LocalVarId, +) -> bool { + match &decl.implementation { + CallableImpl::Intrinsic => true, + CallableImpl::Spec(spec_impl) => all_uses_in_spec_impl(package, spec_impl, local_id), + CallableImpl::SimulatableIntrinsic(spec) => all_uses_in_spec(package, spec, local_id), + } +} + +/// Returns `true` when every specialization (body, adjoint, controlled, +/// controlled-adjoint) uses `local_id` exclusively via field access. 
+fn all_uses_in_spec_impl(package: &Package, spec_impl: &SpecImpl, local_id: LocalVarId) -> bool { + if !all_uses_in_spec(package, &spec_impl.body, local_id) { + return false; + } + for spec in functored_specs(spec_impl) { + if !all_uses_in_spec(package, spec, local_id) { + return false; + } + } + true +} + +/// Returns `true` when every use of `local_id` in a single `SpecDecl` body +/// is a field access (per the classifier in [`collect_uses_in_block`]). +fn all_uses_in_spec(package: &Package, spec: &SpecDecl, local_id: LocalVarId) -> bool { + let mut uses = Vec::new(); + collect_uses_in_block(package, spec.block, local_id, &mut uses); + uses.iter().all(|u| *u) +} + +/// Collects all `LocalItemId`s of callables in this package that appear as +/// `Var(Res::Item(id))` with an `Arrow` type (i.e., used as a first-class +/// value rather than as the callee of `Call`). +fn collect_first_class_callables( + package: &Package, + package_id: PackageId, + reachable: &FxHashSet, +) -> FxHashSet { + let mut first_class = FxHashSet::default(); + + // Scan the entry expression. + if let Some(entry_id) = package.entry { + scan_first_class_in_expr(package, package_id, entry_id, &mut first_class); + } + + // Scan every reachable callable body. 
+ for item_id in reachable { + if item_id.package != package_id { + continue; + } + let item = package.get_item(item_id.item); + if let ItemKind::Callable(decl) = &item.kind { + scan_first_class_in_callable(package, package_id, decl, &mut first_class); + } + } + + first_class +} + +fn scan_first_class_in_callable( + package: &Package, + package_id: PackageId, + decl: &CallableDecl, + first_class: &mut FxHashSet, +) { + match &decl.implementation { + CallableImpl::Intrinsic => {} + CallableImpl::Spec(spec_impl) => { + scan_first_class_in_block(package, package_id, spec_impl.body.block, first_class); + for spec in functored_specs(spec_impl) { + scan_first_class_in_block(package, package_id, spec.block, first_class); + } + } + CallableImpl::SimulatableIntrinsic(spec) => { + scan_first_class_in_block(package, package_id, spec.block, first_class); + } + } +} + +fn scan_first_class_in_block( + package: &Package, + package_id: PackageId, + block_id: BlockId, + first_class: &mut FxHashSet, +) { + let block = package.get_block(block_id); + for &stmt_id in &block.stmts { + let stmt = package.get_stmt(stmt_id); + match &stmt.kind { + StmtKind::Expr(e) | StmtKind::Semi(e) => { + scan_first_class_in_expr(package, package_id, *e, first_class); + } + StmtKind::Local(_, _, expr) => { + scan_first_class_in_expr(package, package_id, *expr, first_class); + } + StmtKind::Item(_) => {} + } + } +} + +/// Scans an expression tree. A `Var(Res::Item(id))` with `Ty::Arrow` is +/// considered first-class UNLESS it appears as the direct callee of a `Call`. +fn scan_first_class_in_expr( + package: &Package, + package_id: PackageId, + expr_id: ExprId, + first_class: &mut FxHashSet, +) { + let expr = package.get_expr(expr_id); + match &expr.kind { + ExprKind::Call(callee, args) => { + // The callee position is a direct call — don't mark it. + // But still recurse into the callee's sub-expressions + // (e.g., if callee is Field(...), that's not a direct Var). 
+ let callee_expr = package.get_expr(*callee); + match &callee_expr.kind { + ExprKind::Var(Res::Item(_), _) => { + // Direct call — skip marking, but recurse into args. + } + ExprKind::UnOp(_, inner) => { + // Functor-applied call: check if inner is a direct item ref. + let inner_expr = package.get_expr(*inner); + if !matches!(inner_expr.kind, ExprKind::Var(Res::Item(_), _)) { + // Not a direct functor application — recurse into callee. + scan_first_class_in_expr(package, package_id, *callee, first_class); + } + // If inner IS a direct Var(Item), this is a direct functor-applied + // call (e.g., Adjoint Foo(args)) — don't mark as first-class. + } + _ => { + scan_first_class_in_expr(package, package_id, *callee, first_class); + } + } + scan_first_class_in_expr(package, package_id, *args, first_class); + } + ExprKind::Var(Res::Item(item_id), _) if matches!(&expr.ty, Ty::Arrow(_)) => { + if item_id.package == package_id { + first_class.insert(item_id.item); + } + } + // Recurse into all sub-expressions. 
+ ExprKind::Array(es) | ExprKind::ArrayLit(es) | ExprKind::Tuple(es) => { + for &e in es { + scan_first_class_in_expr(package, package_id, e, first_class); + } + } + ExprKind::ArrayRepeat(a, b) + | ExprKind::Assign(a, b) + | ExprKind::AssignOp(_, a, b) + | ExprKind::BinOp(_, a, b) + | ExprKind::Index(a, b) + | ExprKind::AssignField(a, _, b) + | ExprKind::UpdateField(a, _, b) => { + scan_first_class_in_expr(package, package_id, *a, first_class); + scan_first_class_in_expr(package, package_id, *b, first_class); + } + ExprKind::AssignIndex(a, b, c) | ExprKind::UpdateIndex(a, b, c) => { + scan_first_class_in_expr(package, package_id, *a, first_class); + scan_first_class_in_expr(package, package_id, *b, first_class); + scan_first_class_in_expr(package, package_id, *c, first_class); + } + ExprKind::Block(block_id) => { + scan_first_class_in_block(package, package_id, *block_id, first_class); + } + ExprKind::Fail(e) | ExprKind::Field(e, _) | ExprKind::Return(e) | ExprKind::UnOp(_, e) => { + scan_first_class_in_expr(package, package_id, *e, first_class); + } + ExprKind::If(cond, body, otherwise) => { + scan_first_class_in_expr(package, package_id, *cond, first_class); + scan_first_class_in_expr(package, package_id, *body, first_class); + if let Some(e) = otherwise { + scan_first_class_in_expr(package, package_id, *e, first_class); + } + } + ExprKind::Range(s, st, e) => { + for x in [s, st, e].into_iter().flatten() { + scan_first_class_in_expr(package, package_id, *x, first_class); + } + } + ExprKind::Struct(_, copy, fields) => { + if let Some(c) = copy { + scan_first_class_in_expr(package, package_id, *c, first_class); + } + for fa in fields { + scan_first_class_in_expr(package, package_id, fa.value, first_class); + } + } + ExprKind::String(components) => { + for c in components { + if let qsc_fir::fir::StringComponent::Expr(e) = c { + scan_first_class_in_expr(package, package_id, *e, first_class); + } + } + } + ExprKind::While(cond, block_id) => { + 
scan_first_class_in_expr(package, package_id, *cond, first_class); + scan_first_class_in_block(package, package_id, *block_id, first_class); + } + ExprKind::Hole | ExprKind::Lit(_) | ExprKind::Var(_, _) | ExprKind::Closure(_, _) => {} + } +} + +/// Collects all `LocalItemId`s that are targets of `Closure(_, local_item_id)` +/// in the entry-reachable portion of the current package. +fn collect_closure_targets( + package: &Package, + package_id: PackageId, + reachable: &FxHashSet, +) -> FxHashSet { + let mut targets = FxHashSet::default(); + + if let Some(entry_id) = package.entry { + for_each_expr(package, entry_id, &mut |_expr_id, expr| { + if let ExprKind::Closure(_, local_item_id) = &expr.kind { + targets.insert(*local_item_id); + } + }); + } + + for item_id in reachable { + if item_id.package != package_id { + continue; + } + + let item = package.get_item(item_id.item); + if let ItemKind::Callable(decl) = &item.kind { + for_each_expr_in_callable_impl(package, &decl.implementation, &mut |_expr_id, expr| { + if let ExprKind::Closure(_, local_item_id) = &expr.kind { + targets.insert(*local_item_id); + } + }); + } + } + + targets +} + +/// Promotes a single candidate in-place: decomposes the input pattern and +/// rewrites field accesses in all specialization bodies. Returns a +/// `PromotionResult` only for top-level promotions (which require call-site +/// rewriting). +/// +/// # Before +/// ```text +/// input pat = Bind(p : (A, B)) +/// body: Field(Var(Local(p)), Path([0])); Field(Var(Local(p)), Path([1])) +/// ``` +/// # After +/// ```text +/// input pat = Tuple([Bind(p_0 : A), Bind(p_1 : B)]) +/// body: Var(Local(p_0)); Var(Local(p_1)) +/// ``` +/// +/// # Mutations +/// - Rewrites the input `Pat` from `Bind` to `Tuple` of per-element `Bind`s. +/// - Allocates new `LocalVarId`, `PatId` nodes through `assigner`. +/// - Delegates to [`rewrite_field_accesses`] to rewrite body expressions. 
+fn promote_candidate( + package: &mut Package, + assigner: &mut Assigner, + candidate: &ArgPromoCandidate, +) -> Option { + let new_locals = decompose_binding( + package, + assigner, + candidate.pat_id, + &candidate.name, + &candidate.elem_types, + ); + + // Rewrite field accesses scoped to the promoted callable's body. + rewrite_field_accesses( + package, + assigner, + candidate.item_id, + candidate.local_id, + &new_locals, + &candidate.elem_types, + ); + + if candidate.is_top_level { + Some(PromotionResult { + item_id: candidate.item_id, + elem_types: candidate.elem_types.clone(), + }) + } else { + None + } +} + +/// Rewrites field accesses on the old local to use the new decomposed locals. +/// +/// Scoped to the promoted callable's body expressions only, since `old_local` +/// is a parameter binding that can only appear in the declaring callable. +/// +/// # Before +/// ```text +/// Field(Var(Local(old)), Path([i])) // param.i +/// ``` +/// # After +/// ```text +/// Var(Local(old_i)) // direct scalar reference +/// ``` +/// +/// # Mutations +/// - Rewrites `Expr.kind` in place for matching `Field` and `AssignField` +/// expressions via [`rewrite_single_field_expr`]. +fn rewrite_field_accesses( + package: &mut Package, + assigner: &mut Assigner, + item_id: LocalItemId, + old_local: LocalVarId, + new_locals: &[LocalVarId], + elem_types: &[Ty], +) { + let expr_ids = collect_expr_ids_in_local_callables(&*package, &[item_id]); + for expr_id in expr_ids { + rewrite_single_field_expr( + package, assigner, expr_id, old_local, new_locals, elem_types, + ); + } +} + +/// Rewrites a single expression that projects a field of the now-promoted +/// parameter so it references the corresponding new scalar parameter +/// binding directly. 
+/// +/// Handles two expression shapes: +/// +/// # Before (`Field` read) +/// ```text +/// Field(Var(Local(old_param)), Path([i])) // single-index +/// Field(Var(Local(old_param)), Path([i, j])) // nested +/// ``` +/// # After (`Field` read) +/// ```text +/// Var(Local(param_i)) // single-index: direct +/// Field(Var(Local(param_i)), Path([j])) // nested: re-rooted +/// ``` +/// +/// # Before (`AssignField`) +/// ```text +/// AssignField(Var(Local(old_param)), Path([i]), value) +/// ``` +/// # After (`AssignField`) +/// ```text +/// Assign(Var(Local(param_i)), value) +/// ``` +/// +/// # Mutations +/// - Rewrites `Expr.kind` and `Expr.ty` in place for the matched expression. +/// - Allocates new `Var` `Expr` nodes through `assigner` for nested and +/// assign-field paths. +fn rewrite_single_field_expr( + package: &mut Package, + assigner: &mut Assigner, + expr_id: ExprId, + old_local: LocalVarId, + new_locals: &[LocalVarId], + elem_types: &[Ty], +) { + let expr = package.exprs.get(expr_id).expect("expr should exist"); + match expr.kind.clone() { + ExprKind::Field(inner_id, Field::Path(path)) => { + let inner = package + .exprs + .get(inner_id) + .expect("inner expr should exist"); + if let ExprKind::Var(Res::Local(var_id), _) = &inner.kind + && *var_id == old_local + && !path.indices.is_empty() + { + let idx = path.indices[0]; + if idx < new_locals.len() { + if path.indices.len() == 1 { + let new_local = new_locals[idx]; + let new_ty = elem_types[idx].clone(); + let expr_mut = package.exprs.get_mut(expr_id).expect("expr exists"); + expr_mut.kind = ExprKind::Var(Res::Local(new_local), vec![]); + expr_mut.ty = new_ty; + } else { + let new_local = new_locals[idx]; + let remaining: Vec = path.indices[1..].to_vec(); + + let new_inner_id = assigner.next_expr(); + package.exprs.insert( + new_inner_id, + Expr { + id: new_inner_id, + span: Span::default(), + ty: elem_types[idx].clone(), + kind: ExprKind::Var(Res::Local(new_local), vec![]), + exec_graph_range: 
EMPTY_EXEC_RANGE, + }, + ); + + let expr_mut = package.exprs.get_mut(expr_id).expect("expr exists"); + expr_mut.kind = ExprKind::Field( + new_inner_id, + Field::Path(FieldPath { indices: remaining }), + ); + } + } + } + } + ExprKind::AssignField(record_id, Field::Path(path), value_id) => { + let record = package + .exprs + .get(record_id) + .expect("record expr should exist"); + if let ExprKind::Var(Res::Local(var_id), _) = &record.kind + && *var_id == old_local + && !path.indices.is_empty() + { + let idx = path.indices[0]; + if idx < new_locals.len() && path.indices.len() == 1 { + let new_local = new_locals[idx]; + + let new_record_id = assigner.next_expr(); + package.exprs.insert( + new_record_id, + Expr { + id: new_record_id, + span: Span::default(), + ty: elem_types[idx].clone(), + kind: ExprKind::Var(Res::Local(new_local), vec![]), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + + let expr_mut = package.exprs.get_mut(expr_id).expect("expr exists"); + expr_mut.kind = ExprKind::Assign(new_record_id, value_id); + } + } + } + _ => {} + } +} + +/// Rewrites all call sites for promoted callables. At each `Call(Var(Item(id)), +/// arg)` where `id` is a promoted callable, replaces the single tuple argument +/// with explicit field extractions wrapped in a `Tuple`. +/// +/// # Before +/// ```text +/// Foo(struct_arg) // single composite argument +/// ``` +/// # After +/// ```text +/// Foo((struct_arg.0, struct_arg.1)) // explicit field projections +/// ``` +/// +/// # Mutations +/// - Rewrites call-site `Expr.kind` in place or wraps in a block when +/// a temporary is needed to avoid evaluating the argument multiple times. +/// - Allocates field-projection and tuple `Expr` nodes through `assigner`. +fn rewrite_call_sites( + package: &mut Package, + package_id: PackageId, + assigner: &mut Assigner, + promotions: &[PromotionResult], + reachable_expr_ids: &[ExprId], +) { + // Build a set of promoted item IDs for quick lookup. 
+ let promoted: FxHashSet = promotions.iter().map(|p| p.item_id).collect(); + + // Collect all call-site ExprIds that target a promoted callable. + let call_sites: Vec<(ExprId, LocalItemId)> = reachable_expr_ids + .iter() + .filter_map(|&expr_id| { + let expr = package.exprs.get(expr_id)?; + if let ExprKind::Call(callee_id, _) = &expr.kind { + let callee = package.exprs.get(*callee_id)?; + if let ExprKind::Var(Res::Item(item_id), _) = &callee.kind + && item_id.package == package_id + && promoted.contains(&item_id.item) + { + return Some((expr_id, item_id.item)); + } + } + None + }) + .collect(); + + for (call_expr_id, item_id) in call_sites { + let promotion = promotions + .iter() + .find(|p| p.item_id == item_id) + .expect("promotion should exist for promoted item"); + rewrite_single_call_site(package, assigner, call_expr_id, promotion); + } +} + +/// Returns `true` when an argument expression can be projected repeatedly +/// without side effects (e.g. literals, plain `Var` references), letting +/// the caller inline each projected field without introducing a +/// temporary. +fn expr_is_safe_to_project_repeatedly(package: &Package, expr_id: ExprId) -> bool { + match &package.get_expr(expr_id).kind { + ExprKind::Var(Res::Local(_), _) => true, + ExprKind::Field(inner_id, Field::Path(_)) => { + expr_is_safe_to_project_repeatedly(package, *inner_id) + } + _ => false, + } +} + +/// Creates a temporary `let temp = arg_expr;` binding for argument +/// expressions that cannot be projected repeatedly without +/// side-effect duplication. The caller replaces subsequent field +/// projections with references to `temp`. +/// +/// # Before +/// ```text +/// (no binding) +/// ``` +/// # After +/// ```text +/// let __arg_promote_tmp : T = arg_expr; +/// ``` +/// +/// # Mutations +/// - Allocates a new `Pat`, `LocalVarId`, and `Stmt` through `assigner`. 
+fn create_projection_temp_binding( + package: &mut Package, + assigner: &mut Assigner, + arg_id: ExprId, + arg_ty: &Ty, +) -> (LocalVarId, StmtId) { + let local_id = assigner.next_local(); + let pat_id = assigner.next_pat(); + package.pats.insert( + pat_id, + Pat { + id: pat_id, + span: Span::default(), + ty: arg_ty.clone(), + kind: PatKind::Bind(Ident { + id: local_id, + span: Span::default(), + name: Rc::from("__arg_promote_tmp"), + }), + }, + ); + + let stmt_id = assigner.next_stmt(); + package.stmts.insert( + stmt_id, + Stmt { + id: stmt_id, + span: Span::default(), + kind: StmtKind::Local(Mutability::Immutable, pat_id, arg_id), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + + (local_id, stmt_id) +} + +/// Allocates a fresh `ExprKind::Var(Res::Local(var))` expression with the +/// given type, used to materialize references to synthesized temporaries +/// and promoted parameters. +/// +/// # Mutations +/// - Inserts one `Expr` node through `assigner`. +fn create_local_var_expr( + package: &mut Package, + assigner: &mut Assigner, + local_id: LocalVarId, + ty: &Ty, +) -> ExprId { + let expr_id = assigner.next_expr(); + package.exprs.insert( + expr_id, + Expr { + id: expr_id, + span: Span::default(), + ty: ty.clone(), + kind: ExprKind::Var(Res::Local(local_id), vec![]), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + expr_id +} + +/// Builds the projected tuple that replaces the original tuple argument at +/// a call site, pairing each projected sub-expression with the type slot +/// expected by the promoted callable signature. +/// +/// # Before +/// ```text +/// (no expression) +/// ``` +/// # After +/// ```text +/// Tuple([Field(arg, Path([0])), ..., Field(arg, Path([n-1]))]) +/// ``` +/// +/// # Mutations +/// - Allocates per-element `Field` `Expr` nodes and the outer `Tuple` +/// `Expr` through `assigner`. 
+fn create_projected_tuple_arg( + package: &mut Package, + assigner: &mut Assigner, + promotion: &PromotionResult, + arg_id: ExprId, + arg_ty: &Ty, + temp_local: Option, +) -> ExprId { + let n = promotion.elem_types.len(); + let mut field_expr_ids: Vec = Vec::with_capacity(n); + + for i in 0..n { + let field_base_id = if let Some(temp_local) = temp_local { + create_local_var_expr(package, assigner, temp_local, arg_ty) + } else { + arg_id + }; + let field_expr_id = assigner.next_expr(); + let field_expr = qsc_fir::fir::Expr { + id: field_expr_id, + span: Span::default(), + ty: promotion.elem_types[i].clone(), + kind: ExprKind::Field(field_base_id, Field::Path(FieldPath { indices: vec![i] })), + exec_graph_range: EMPTY_EXEC_RANGE, + }; + package.exprs.insert(field_expr_id, field_expr); + field_expr_ids.push(field_expr_id); + } + + let new_arg_id = assigner.next_expr(); + let tuple_ty = Ty::Tuple(promotion.elem_types.clone()); + let new_arg = qsc_fir::fir::Expr { + id: new_arg_id, + span: Span::default(), + ty: tuple_ty, + kind: ExprKind::Tuple(field_expr_ids), + exec_graph_range: EMPTY_EXEC_RANGE, + }; + package.exprs.insert(new_arg_id, new_arg); + new_arg_id +} + +/// Wraps an existing `Call` expression in a synthesized block that places +/// a pre-built leading statement (typically a temporary binding) before +/// the call, preserving evaluation order. +/// +/// # Before +/// ```text +/// call_expr_id = Call(callee_id, _) +/// ``` +/// # After +/// ```text +/// call_expr_id = Block { +/// leading_stmt; // supplied by caller +/// Expr(Call(callee_id, new_arg_id)) // inner call with rewritten args +/// } +/// ``` +/// +/// # Mutations +/// - Replaces `call_expr_id`'s `ExprKind` with `Block(..)` in place. +/// - Allocates inner `Call`, `Stmt`, and `Block` nodes through `assigner`. 
+fn wrap_call_in_block( + package: &mut Package, + assigner: &mut Assigner, + call_expr_id: ExprId, + callee_id: ExprId, + new_arg_id: ExprId, + call_ty: &Ty, + leading_stmt_id: StmtId, +) { + let inner_call_id = assigner.next_expr(); + package.exprs.insert( + inner_call_id, + Expr { + id: inner_call_id, + span: Span::default(), + ty: call_ty.clone(), + kind: ExprKind::Call(callee_id, new_arg_id), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + + let call_stmt_id = assigner.next_stmt(); + package.stmts.insert( + call_stmt_id, + Stmt { + id: call_stmt_id, + span: Span::default(), + kind: StmtKind::Expr(inner_call_id), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + + let block_id = assigner.next_block(); + package.blocks.insert( + block_id, + Block { + id: block_id, + span: Span::default(), + ty: call_ty.clone(), + stmts: vec![leading_stmt_id, call_stmt_id], + }, + ); + + let call_mut = package + .exprs + .get_mut(call_expr_id) + .expect("call expr exists"); + call_mut.kind = ExprKind::Block(block_id); +} + +/// Rewrites a single call site: `Foo(arg)` → `Foo((arg.0, arg.1, ...))`. +/// +/// # Before +/// ```text +/// Call(Var(Foo), composite_arg) +/// ``` +/// # After +/// ```text +/// Call(Var(Foo), Tuple([arg.0, arg.1, ...])) // or Block wrapping +/// ``` +/// +/// If the argument is already a `Tuple(...)` with the correct arity, the +/// existing tuple elements are used directly. Otherwise, field-extraction +/// expressions are created. +/// +/// # Mutations +/// - Rewrites `call_expr_id`'s `ExprKind` in place. +/// - May allocate projection, tuple, and temporary `Expr`/`Stmt` nodes +/// through `assigner`. 
+fn rewrite_single_call_site( + package: &mut Package, + assigner: &mut Assigner, + call_expr_id: ExprId, + promotion: &PromotionResult, +) { + let call_expr = package.exprs.get(call_expr_id).expect("call expr exists"); + let ExprKind::Call(callee_id, arg_id) = call_expr.kind else { + return; + }; + let call_ty = call_expr.ty.clone(); + + let arg_expr = package.exprs.get(arg_id).expect("arg expr exists"); + let arg_ty = arg_expr.ty.clone(); + + // If the argument is already a tuple literal with matching arity, + // the call site is already structured correctly. + if let ExprKind::Tuple(elems) = &arg_expr.kind + && elems.len() == promotion.elem_types.len() + { + return; + } + + if promotion.elem_types.len() == 1 { + let new_arg_id = assigner.next_expr(); + let new_arg = qsc_fir::fir::Expr { + id: new_arg_id, + span: Span::default(), + ty: Ty::Tuple(promotion.elem_types.clone()), + kind: ExprKind::Tuple(vec![arg_id]), + exec_graph_range: EMPTY_EXEC_RANGE, + }; + package.exprs.insert(new_arg_id, new_arg); + + let call_mut = package + .exprs + .get_mut(call_expr_id) + .expect("call expr exists"); + call_mut.kind = ExprKind::Call(callee_id, new_arg_id); + return; + } + + let temp_binding = if expr_is_safe_to_project_repeatedly(package, arg_id) { + None + } else { + Some(create_projection_temp_binding( + package, assigner, arg_id, &arg_ty, + )) + }; + let new_arg_id = create_projected_tuple_arg( + package, + assigner, + promotion, + arg_id, + &arg_ty, + temp_binding.map(|(temp_local, _)| temp_local), + ); + + if let Some((_, temp_stmt_id)) = temp_binding { + wrap_call_in_block( + package, + assigner, + call_expr_id, + callee_id, + new_arg_id, + &call_ty, + temp_stmt_id, + ); + } else { + let call_mut = package + .exprs + .get_mut(call_expr_id) + .expect("call expr exists"); + call_mut.kind = ExprKind::Call(callee_id, new_arg_id); + } +} + +/// Normalizes call argument expression shapes to exactly match callee input +/// types. 
+/// +/// This pass is intentionally run after fixed-point promotion converges, +/// because prior rewrites can leave call arguments with shape-equivalent but +/// type-distinct forms (most notably `T` vs `(T)` for single-element tuples). +/// +/// # Before +/// ```text +/// operation UseOne(p : (Qubit[])) : Unit { ... } +/// UseOne(qs); // arg type Qubit[] +/// ``` +/// +/// # After +/// ```text +/// operation UseOne(p : (Qubit[])) : Unit { ... } +/// UseOne((qs,)); // arg type (Qubit[]) +/// ``` +/// +/// # Ensures +/// - For every direct call expression, argument type structure matches the +/// expected callable input type where normalization can be done locally. +/// - Does not rewrite callee declarations; only argument expression shape. +fn normalize_call_arg_types( + package: &mut Package, + package_id: PackageId, + assigner: &mut Assigner, + reachable_expr_ids: &[ExprId], +) { + let call_sites: Vec<(ExprId, Ty)> = reachable_expr_ids + .iter() + .filter_map(|&expr_id| { + let expr = package.exprs.get(expr_id)?; + let ExprKind::Call(callee_id, arg_id) = expr.kind else { + return None; + }; + resolve_expected_input(package, package_id, callee_id) + .map(|expected_input| (arg_id, expected_input)) + }) + .collect(); + + for (arg_id, expected_input) in call_sites { + normalize_arg_to_expected_input(package, assigner, arg_id, &expected_input); + } +} + +fn resolve_expected_input( + package: &Package, + package_id: PackageId, + callee_id: ExprId, +) -> Option { + let callee = package.get_expr(callee_id); + if let ExprKind::Var(Res::Item(item_id), _) = &callee.kind + && item_id.package == package_id + { + let item = package.items.get(item_id.item)?; + if let ItemKind::Callable(decl) = &item.kind { + return Some(package.get_pat(decl.input).ty.clone()); + } + } + + if let Ty::Arrow(arrow) = &callee.ty { + return Some((*arrow.input).clone()); + } + + None +} + +/// Reconciles a rewritten call-site argument subtree with the callee's current +/// input type. 
+/// +/// Before, `arg_id` may still reflect the pre-promotion shape, such as a scalar +/// where the promoted callee now expects `(scalar,)`, or nested tuple children +/// whose wrapper structure no longer matches the updated input pattern. After, +/// the subtree rooted at `arg_id` mirrors `expected_input`: single-element tuple +/// wrappers are inserted only where required and tuple types are refreshed after +/// recursive normalization. +fn normalize_arg_to_expected_input( + package: &mut Package, + assigner: &mut Assigner, + arg_id: ExprId, + expected_input: &Ty, +) { + let arg = package.get_expr(arg_id).clone(); + if arg.ty == *expected_input { + return; + } + + let Ty::Tuple(expected_items) = expected_input else { + return; + }; + + if expected_items.len() == 1 && arg.ty == expected_items[0] { + if matches!(&arg.kind, ExprKind::Tuple(items) if items.len() == 1) { + return; + } + wrap_arg_in_single_tuple(package, assigner, arg_id); + return; + } + + let ExprKind::Tuple(arg_items) = arg.kind else { + return; + }; + if arg_items.len() != expected_items.len() { + return; + } + + for (arg_item, expected_item) in arg_items.iter().zip(expected_items) { + normalize_arg_to_expected_input(package, assigner, *arg_item, expected_item); + } + + let updated_tys = arg_items + .iter() + .map(|arg_item| package.get_expr(*arg_item).ty.clone()) + .collect(); + let arg_mut = package.exprs.get_mut(arg_id).expect("arg expr exists"); + arg_mut.ty = Ty::Tuple(updated_tys); +} + +/// Replaces `arg_id` with a one-element tuple node while preserving the +/// original argument under a freshly allocated child expression. +/// +/// Before, `arg_id` points directly at the scalar or tuple element supplied at +/// the call site. After, the original payload lives at `preserved_arg_id` and +/// `arg_id` becomes `(payload)`, matching callees whose promoted signature still +/// expects a single tuple layer. 
+fn wrap_arg_in_single_tuple(package: &mut Package, assigner: &mut Assigner, arg_id: ExprId) { + let original_arg = package.get_expr(arg_id).clone(); + let preserved_arg_id = assigner.next_expr(); + package.exprs.insert( + preserved_arg_id, + Expr { + id: preserved_arg_id, + span: original_arg.span, + ty: original_arg.ty.clone(), + kind: original_arg.kind, + exec_graph_range: original_arg.exec_graph_range, + }, + ); + + let arg = package.exprs.get_mut(arg_id).expect("arg expr exists"); + arg.kind = ExprKind::Tuple(vec![preserved_arg_id]); + arg.ty = Ty::Tuple(vec![original_arg.ty]); +} diff --git a/source/compiler/qsc_fir_transforms/src/arg_promote/semantic_equivalence_tests.rs b/source/compiler/qsc_fir_transforms/src/arg_promote/semantic_equivalence_tests.rs new file mode 100644 index 0000000000..febc443c3d --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/arg_promote/semantic_equivalence_tests.rs @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use indoc::formatdoc; +use indoc::indoc; +use proptest::prelude::*; + +#[test] +fn tuple_param_flattened_preserves_semantics() { + crate::test_utils::check_semantic_equivalence(indoc! {r#" + namespace Test { + function Add(pair : (Int, Int)) : Int { + let (a, b) = pair; + a + b + } + + @EntryPoint() + function Main() : Int { + Add((3, 4)) + } + } + "#}); +} + +#[test] +fn nested_tuple_param_flattened_preserves_semantics() { + crate::test_utils::check_semantic_equivalence(indoc! {r#" + namespace Test { + function Sum(args : ((Int, Int), Int)) : Int { + let ((a, b), c) = args; + a + b + c + } + + @EntryPoint() + function Main() : Int { + Sum(((1, 2), 3)) + } + } + "#}); +} + +#[test] +fn mixed_scalar_and_tuple_params_preserves_semantics() { + crate::test_utils::check_semantic_equivalence(indoc! 
{r#" + namespace Test { + function Weighted(scale : Int, pair : (Int, Int)) : Int { + let (x, y) = pair; + scale * (x + y) + } + + @EntryPoint() + function Main() : Int { + Weighted(2, (5, 7)) + } + } + "#}); +} + +fn tuple_parameter_argument_pattern() -> impl Strategy { + (2usize..=4, prop::collection::vec(-20i64..=20, 4)).prop_map(|(width, argument_values)| { + let parameter_type = (0..width).map(|_| "Int").collect::>().join(", "); + let field_bindings = (0..width) + .map(|index| format!("field{index}")) + .collect::>() + .join(", "); + let arguments = argument_values + .into_iter() + .take(width) + .map(|value| value.to_string()) + .collect::>() + .join(", "); + + formatdoc! {r#" + namespace Test {{ + function ProjectFirst(parameter : ({parameter_type})) : Int {{ + let ({field_bindings}) = parameter; + field0 + }} + + @EntryPoint() + function Main() : Int {{ + ProjectFirst(({arguments})) + }} + }} + "#} + }) +} + +proptest! { + #![proptest_config(ProptestConfig::with_cases(50))] + + #[test] + fn tuple_parameter_argument_promotion_preserves_semantics(source in tuple_parameter_argument_pattern()) { + crate::test_utils::check_semantic_equivalence(&source); + } +} + +fn qsharp_bool(value: bool) -> &'static str { + if value { "true" } else { "false" } +} + +fn nested_mixed_struct_callable_strategy() -> impl Strategy { + ( + -20i64..=20, + prop::bool::ANY, + -20i64..=20, + prop::bool::ANY, + prop::bool::ANY, + ) + .prop_map(|(value, flag, bonus, enabled, prefer_alias)| { + let flag = qsharp_bool(flag); + let enabled = qsharp_bool(enabled); + let selector = qsharp_bool(prefer_alias); + + formatdoc! 
{r#" + namespace Test {{ + struct Inner {{ Value : Int, Flag : Bool }} + struct Outer {{ Left : Inner, Bonus : Int, Enabled : Bool }} + + function Sum(input : Outer) : Int {{ + let signed = if input.Left.Flag {{ input.Left.Value }} else {{ -input.Left.Value }}; + if input.Enabled {{ signed + input.Bonus }} else {{ signed - input.Bonus }} + }} + + @EntryPoint() + function Main() : Int {{ + let input = new Outer {{ + Left = new Inner {{ Value = {value}, Flag = {flag} }}, + Bonus = {bonus}, + Enabled = {enabled} + }}; + let f = Sum; + let viaAlias = f(input); + let direct = Sum(input); + if {selector} {{ viaAlias }} else {{ direct }} + }} + }} + "#} + }) +} + +proptest! { + #![proptest_config(ProptestConfig::with_cases(32))] + + #[test] + fn nested_mixed_struct_callable_arg_promotion_preserves_semantics( + source in nested_mixed_struct_callable_strategy() + ) { + crate::test_utils::check_semantic_equivalence(&source); + } +} diff --git a/source/compiler/qsc_fir_transforms/src/arg_promote/tests.rs b/source/compiler/qsc_fir_transforms/src/arg_promote/tests.rs new file mode 100644 index 0000000000..fd48c663b2 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/arg_promote/tests.rs @@ -0,0 +1,987 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
use super::*;
use crate::test_utils::{PipelineStage, compile_and_run_pipeline_to, compile_to_fir};
use expect_test::{Expect, expect};
use indoc::indoc;
use qsc_fir::assigner::Assigner;
use qsc_fir::fir::{
    CallableDecl, CallableImpl, ExprId, ExprKind, Field, FieldPath, ItemKind, LocalVarId,
    Mutability, PackageLookup, PatKind, Res, StmtKind,
};
use rustc_hash::FxHashMap;

/// Compiles `source`, runs the pipeline up to (and including) arg promotion,
/// and compares the summarized callable signatures against `expect`.
fn check(source: &str, expect: &Expect) {
    let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ArgPromote);
    let result = extract_result(&store, pkg_id);
    expect.assert_eq(&result);
}

/// Summarizes each entry-reachable callable in the package as its name, input
/// pattern, and `let`-bound local patterns, sorted for stable output.
fn extract_result(store: &PackageStore, pkg_id: PackageId) -> String {
    let package = store.get(pkg_id);
    let reachable = crate::reachability::collect_reachable_from_entry(store, pkg_id);
    let mut entries: Vec<String> = Vec::new();
    for store_id in &reachable {
        if store_id.package != pkg_id {
            continue;
        }
        let item = package.get_item(store_id.item);
        if let ItemKind::Callable(decl) = &item.kind {
            let mut lines = Vec::new();
            lines.push(format!(
                "Callable {}: input={}",
                decl.name.name,
                format_pat(package, decl.input)
            ));
            if let CallableImpl::Spec(spec) = &decl.implementation {
                let block = package.get_block(spec.body.block);
                for &stmt_id in &block.stmts {
                    let stmt = package.get_stmt(stmt_id);
                    if let StmtKind::Local(mutability, pat_id, _) = &stmt.kind {
                        let mut_str = if matches!(mutability, Mutability::Mutable) {
                            "mutable "
                        } else {
                            ""
                        };
                        lines.push(format!(
                            "  local: {}{}",
                            mut_str,
                            format_pat(package, *pat_id)
                        ));
                    }
                }
            }
            entries.push(lines.join("\n"));
        }
    }
    entries.sort();
    entries.join("\n")
}

/// Renders a pattern as `Bind(name: ty)`, `Tuple(...)`, or `Discard(ty)`.
fn format_pat(package: &qsc_fir::fir::Package, pat_id: PatId) -> String {
    let pat = package.get_pat(pat_id);
    match &pat.kind {
        PatKind::Bind(ident) => format!("Bind({}: {})", ident.name, pat.ty),
        PatKind::Tuple(sub_pats) => {
            let subs: Vec<String> = sub_pats.iter().map(|&id| format_pat(package, id)).collect();
            format!("Tuple({})", subs.join(", "))
        }
        PatKind::Discard => format!("Discard({})", pat.ty),
    }
}

/// Looks up a callable declaration by name, panicking when absent.
fn find_callable<'a>(package: &'a qsc_fir::fir::Package, callable_name: &str) -> &'a CallableDecl {
    package
        .items
        .values()
        .find_map(|item| match &item.kind {
            ItemKind::Callable(decl) if decl.name.name.as_ref() == callable_name => {
                Some(decl.as_ref())
            }
            _ => None,
        })
        .unwrap_or_else(|| panic!("callable '{callable_name}' not found"))
}

/// Maps every `Bind` pattern's local var id to its identifier name.
fn local_names(package: &qsc_fir::fir::Package) -> FxHashMap<LocalVarId, String> {
    package
        .pats
        .values()
        .filter_map(|pat| match &pat.kind {
            PatKind::Bind(ident) => Some((ident.id, ident.name.to_string())),
            PatKind::Tuple(_) | PatKind::Discard => None,
        })
        .collect()
}

/// Finds the local var id bound under `binding_name` anywhere in a pattern.
fn find_pat_binding_id_by_name(
    package: &qsc_fir::fir::Package,
    pat_id: PatId,
    binding_name: &str,
) -> Option<LocalVarId> {
    let pat = package.get_pat(pat_id);
    match &pat.kind {
        PatKind::Bind(ident) if ident.name.as_ref() == binding_name => Some(ident.id),
        PatKind::Bind(_) | PatKind::Discard => None,
        PatKind::Tuple(sub_pats) => sub_pats
            .iter()
            .find_map(|&sub_pat_id| find_pat_binding_id_by_name(package, sub_pat_id, binding_name)),
    }
}

/// Resolves an item id to its callable name, or a debug rendering otherwise.
fn item_name(package: &qsc_fir::fir::Package, item_id: &qsc_fir::fir::ItemId) -> String {
    package
        .items
        .get(item_id.item)
        .and_then(|item| match &item.kind {
            ItemKind::Callable(decl) => Some(decl.name.name.to_string()),
            _ => None,
        })
        .unwrap_or_else(|| format!("{item_id:?}"))
}

/// Pretty-prints a call operand: field paths as `base.i.j`, literals/unary
/// ops via `Debug`, tuples recursively, and vars by name.
fn format_call_operand(
    package: &qsc_fir::fir::Package,
    names: &FxHashMap<LocalVarId, String>,
    expr_id: ExprId,
) -> String {
    let expr = package.get_expr(expr_id);
    match &expr.kind {
        ExprKind::Field(record_id, Field::Path(path)) => {
            let mut formatted = format_call_operand(package, names, *record_id);
            for index in &path.indices {
                formatted.push('.');
                formatted.push_str(&index.to_string());
            }
            formatted
        }
        ExprKind::Lit(lit) => format!("{lit:?}"),
        ExprKind::Tuple(items) => {
            let items = items
                .iter()
                .map(|item| format_call_operand(package, names, *item))
                .collect::<Vec<_>>()
                .join(", ");
            format!("({items})")
        }
        ExprKind::UnOp(op, operand_id) => {
            format!(
                "{op:?}({})",
                format_call_operand(package, names, *operand_id)
            )
        }
        ExprKind::Var(Res::Item(item_id), _) => item_name(package, item_id),
        ExprKind::Var(Res::Local(local_id), _) => names
            .get(local_id)
            .cloned()
            .unwrap_or_else(|| format!("{local_id:?}")),
        _ => crate::test_utils::expr_kind_short(package, expr_id),
    }
}

/// Lists every call expression inside `callable_name` as `Callee(args)`.
fn extract_call_shapes(store: &PackageStore, pkg_id: PackageId, callable_name: &str) -> String {
    let package = store.get(pkg_id);
    let names = local_names(package);
    let callable = find_callable(package, callable_name);
    let mut calls = Vec::new();

    crate::walk_utils::for_each_expr_in_callable_impl(
        package,
        &callable.implementation,
        &mut |_expr_id, expr| {
            if let ExprKind::Call(callee_id, arg_id) = expr.kind {
                calls.push(format!(
                    "{}({})",
                    format_call_operand(package, &names, callee_id),
                    format_call_operand(package, &names, arg_id),
                ));
            }
        },
    );

    calls.join("\n")
}

/// Lists every distinct field-path access inside `callable_name`, sorted.
fn extract_field_access_shapes(
    store: &PackageStore,
    pkg_id: PackageId,
    callable_name: &str,
) -> String {
    let package = store.get(pkg_id);
    let names = local_names(package);
    let callable = find_callable(package, callable_name);
    let mut accesses = Vec::new();

    crate::walk_utils::for_each_expr_in_callable_impl(
        package,
        &callable.implementation,
        &mut |expr_id, expr| {
            if matches!(expr.kind, ExprKind::Field(_, Field::Path(_))) {
                accesses.push(format_call_operand(package, &names, expr_id));
            }
        },
    );

    accesses.sort();
    accesses.dedup();
    accesses.join("\n")
}

/// Returns the body block of a callable, panicking on intrinsics.
fn callable_body_block_id(
    package: &qsc_fir::fir::Package,
    callable_name: &str,
) -> qsc_fir::fir::BlockId {
    let callable = find_callable(package, callable_name);
    match &callable.implementation {
        CallableImpl::Spec(spec) => spec.body.block,
        CallableImpl::SimulatableIntrinsic(spec) => spec.block,
        CallableImpl::Intrinsic => panic!("callable '{callable_name}' does not have a body"),
    }
}

/// Asserts `expr_id` is a direct call of the named item; returns the arg id.
fn expect_direct_item_call(
    package: &qsc_fir::fir::Package,
    expr_id: ExprId,
    expected_callee: &str,
) -> ExprId {
    let expr = package.get_expr(expr_id);
    let ExprKind::Call(callee_id, arg_id) = &expr.kind else {
        panic!("expected direct call expression, found {:?}", expr.kind);
    };

    let callee = package.get_expr(*callee_id);
    let ExprKind::Var(Res::Item(item_id), _) = &callee.kind else {
        panic!("expected direct item callee, found {:?}", callee.kind);
    };

    assert_eq!(item_name(package, item_id), expected_callee);
    *arg_id
}

/// Test fixture mutation: forces the first two field accesses on `binding_name`
/// inside `callable_name` to share one inner `Var` expression node (with paths
/// `[0, 0]` and `[0, 1]`), simulating aliased FIR subtrees that the promotion
/// pass must handle without double-rewriting.
fn force_shared_nested_field_inner_expr(
    store: &mut PackageStore,
    pkg_id: PackageId,
    callable_name: &str,
    binding_name: &str,
) {
    let (shared_inner_id, first_field_expr_id, second_field_expr_id) = {
        let package = store.get(pkg_id);
        let callable = find_callable(package, callable_name);
        let old_local = find_pat_binding_id_by_name(package, callable.input, binding_name)
            .unwrap_or_else(|| {
                panic!("binding '{binding_name}' not found in callable '{callable_name}'")
            });

        let qsc_fir::ty::Ty::Tuple(elem_tys) = &package.get_pat(callable.input).ty else {
            panic!("callable '{callable_name}' input should be a tuple");
        };
        assert!(
            matches!(elem_tys.first(), Some(qsc_fir::ty::Ty::Tuple(_))),
            "callable '{callable_name}' input should keep a nested tuple in its first element"
        );

        let mut direct_fields = Vec::new();
        crate::walk_utils::for_each_expr_in_callable_impl(
            package,
            &callable.implementation,
            &mut |expr_id, expr| {
                if let ExprKind::Field(inner_id, Field::Path(path)) = &expr.kind {
                    let inner = package.get_expr(*inner_id);
                    if let ExprKind::Var(Res::Local(var_id), _) = &inner.kind
                        && *var_id == old_local
                        && !path.indices.is_empty()
                    {
                        direct_fields.push((expr_id, *inner_id));
                    }
                }
            },
        );

        assert!(
            direct_fields.len() >= 2,
            "expected at least two field accesses in callable '{callable_name}'"
        );

        let (first_field_expr_id, shared_inner_id) = &direct_fields[0];
        let (second_field_expr_id, _) = &direct_fields[1];
        (
            *shared_inner_id,
            *first_field_expr_id,
            *second_field_expr_id,
        )
    };

    let package = store.get_mut(pkg_id);
    for (expr_id, indices) in [
        (first_field_expr_id, vec![0, 0]),
        (second_field_expr_id, vec![0, 1]),
    ] {
        let expr = package
            .exprs
            .get_mut(expr_id)
            .expect("aliased field expr should exist");
        expr.kind = ExprKind::Field(shared_inner_id, Field::Path(FieldPath { indices }));
    }
}

/// Collects all binding names in a pattern, depth-first.
fn collect_pat_binding_names(
    package: &qsc_fir::fir::Package,
    pat_id: PatId,
    names: &mut Vec<String>,
) {
    let pat = package.get_pat(pat_id);
    match &pat.kind {
        PatKind::Bind(ident) => names.push(ident.name.to_string()),
        PatKind::Tuple(sub_pats) => {
            for &sub_pat_id in sub_pats {
                collect_pat_binding_names(package, sub_pat_id, names);
            }
        }
        PatKind::Discard => {}
    }
}

#[test]
fn param_field_access_decomposes() {
    check(
        "struct Pair { X : Int, Y : Int }
        function Foo(p : Pair) : Int { p.X + p.Y }
        function Main() : Int { Foo(new Pair { X = 1, Y = 2 }) }",
        &expect![[r#"
            Callable Foo: input=Tuple(Bind(p_0: Int), Bind(p_1: Int))
            Callable Main: input=Tuple()"#]],
    );
}

#[test]
fn call_site_rewritten_for_variable_arg() {
    check(
        "struct Pair { X : Int, Y : Int }
        function Foo(p : Pair) : Int { p.X + p.Y }
        function Main() : Int {
            let s = new Pair { X = 10, Y = 20 };
            Foo(s)
        }",
        &expect![[r#"
            Callable Foo: input=Tuple(Bind(p_0: Int), Bind(p_1: Int))
            Callable Main: input=Tuple()
              local: Bind(s: (Int, Int))"#]],
    );
}

#[test]
fn whole_param_use_skips_promotion() {
    check(
        "struct Pair { X : Int, Y : Int }
        function Identity(p : Pair) : Pair { p }
        function Main() : Int {
            let r = Identity(new Pair { X = 1, Y = 2 });
            r.X
        }",
        &expect![[r#"
            Callable Identity: input=Bind(p: (Int, Int))
            Callable Main: input=Tuple()
              local: Tuple(Bind(r_0: Int), Bind(r_1: Int))"#]],
    );
}

#[test]
fn triple_param_decomposes() {
    check(
        "struct Triple { A : Int, B : Int, C : Int }
        function Sum(t : Triple) : Int { t.A + t.B + t.C }
        function Main() : Int { Sum(new Triple { A = 1, B = 2, C = 3 }) }",
        &expect![[r#"
            Callable Main: input=Tuple()
            Callable Sum: input=Tuple(Bind(t_0: Int), Bind(t_1: Int), Bind(t_2: Int))"#]],
    );
}

#[test]
fn callable_with_empty_tuple_parameter() {
    // Function with Unit parameter — should not crash, nothing to promote.
    check(
        "function Foo(u : Unit) : Int { 42 }
        function Main() : Int { Foo(()) }",
        &expect![[r#"
            Callable Foo: input=Bind(u: Unit)
            Callable Main: input=Tuple()"#]],
    );
}

#[test]
fn callable_with_single_field_param() {
    // Single-field struct parameters are still promoted. The callable input
    // becomes a one-element tuple pattern and reachable call sites are
    // rewritten to match.
    check(
        "struct Wrapper { Val : Int }
        function Foo(w : Wrapper) : Int { w.Val }
        function Main() : Int { Foo(new Wrapper { Val = 42 }) }",
        &expect![[r#"
            Callable Foo: input=Tuple(Bind(w_0: Int))
            Callable Main: input=Tuple()"#]],
    );
}

#[test]
fn callable_with_nested_tuple_parameter() {
    // Nested struct: outer struct's fields include another struct.
    // Iterative arg_promote decomposes both the outer and inner
    // parameters since the inner tuple's uses are field-only.
    check(
        "struct Inner { A : Int, B : Int }
        struct Outer { Left : Inner, Extra : Int }
        function Foo(o : Outer) : Int { o.Left.A + o.Extra }
        function Main() : Int {
            Foo(new Outer { Left = new Inner { A = 1, B = 2 }, Extra = 3 })
        }",
        &expect![[r#"
            Callable Foo: input=Tuple(Tuple(Bind(o_0_0: Int), Bind(o_0_1: Int)), Bind(o_1: Int))
            Callable Main: input=Tuple()"#]],
    );
}

#[test]
fn operation_with_adj_spec() {
    // Operation with Adj spec: adjoint body should also be updated
    // when parameters are promoted.
    check(
        "struct Pair { X : Int, Y : Int }
        operation Foo(p : Pair) : Unit is Adj {
            body ... {
                let _ = p.X + p.Y;
            }
            adjoint self;
        }
        operation Main() : Unit {
            Foo(new Pair { X = 1, Y = 2 });
        }",
        &expect![[r#"
            Callable Foo: input=Tuple(Bind(p_0: Int), Bind(p_1: Int))
              local: Discard(Int)
            Callable Main: input=Tuple()"#]],
    );
}

#[test]
fn recursive_callable_with_tuple_parameter() {
    // Recursive callable: self-call sites must be rewritten too.
    check(
        "struct Pair { X : Int, Y : Int }
        function Loop(p : Pair, n : Int) : Int {
            if n <= 0 {
                p.X + p.Y
            } else {
                Loop(p, n - 1)
            }
        }
        function Main() : Int {
            Loop(new Pair { X = 1, Y = 2 }, 3)
        }",
        &expect![[r#"
            Callable Loop: input=Tuple(Bind(p: (Int, Int)), Bind(n: Int))
            Callable Main: input=Tuple()"#]],
    );
}

#[test]
fn callable_with_promoted_args_full_pipeline() {
    // Full pipeline integration: SROA + arg_promote both run.
    // Verifies the combined effect: locals decomposed AND params promoted.
    check(
        "struct Pair { X : Int, Y : Int }
        function Add(p : Pair) : Int { p.X + p.Y }
        function Main() : Int {
            let a = new Pair { X = 10, Y = 20 };
            let b = new Pair { X = 30, Y = 40 };
            Add(a) + Add(b)
        }",
        &expect![[r#"
            Callable Add: input=Tuple(Bind(p_0: Int), Bind(p_1: Int))
            Callable Main: input=Tuple()
              local: Bind(a: (Int, Int))
              local: Bind(b: (Int, Int))"#]],
    );
}

#[test]
fn functor_applied_callee_not_first_class() {
    // Adjoint Op(args) is a direct functor-applied call, not a first-class use.
    // Op's struct parameter should still be decomposed.
    check(
        "struct Pair { X : Int, Y : Int }
        operation Op(p : Pair) : Unit is Adj {
            body ... {
                let _ = p.X + p.Y;
            }
            adjoint self;
        }
        @EntryPoint()
        operation Main() : Unit {
            Adjoint Op(new Pair { X = 1, Y = 2 });
        }",
        &expect![[r#"
            Callable Main: input=Tuple()
            Callable Op: input=Tuple(Bind(p_0: Int), Bind(p_1: Int))
              local: Discard(Int)"#]],
    );
}

#[test]
fn multiple_tuple_params_promotion_behavior() {
    // Each tuple-typed parameter is promoted independently when its uses are
    // field-only, even when the callable has multiple parameters.
    check(
        "struct A { X : Int, Y : Int }
        struct B { P : Int, Q : Int }
        function Add(a : A, b : B) : Int {
            a.X + a.Y + b.P + b.Q
        }
        function Main() : Int {
            Add(new A { X = 1, Y = 2 }, new B { P = 3, Q = 4 })
        }",
        &expect![[r#"
            Callable Add: input=Tuple(Tuple(Bind(a_0: Int), Bind(a_1: Int)), Tuple(Bind(b_0: Int), Bind(b_1: Int)))
            Callable Main: input=Tuple()"#]],
    );
}

#[test]
fn unused_first_class_callable_ref_does_not_block_promotion() {
    // The unused `let f = Sum;` no longer survives to arg_promote because the
    // preceding defunctionalization stage prunes dead callable-valued locals.
    // By the time arg_promote runs, `Sum` is no longer referenced as a live
    // first-class value, so its tuple parameter is promoted.
+ check( + "struct Pair { X : Int, Y : Int } + function Sum(p : Pair) : Int { + p.X + p.Y + } + function Main() : Int { + let p = new Pair { X = 1, Y = 2 }; + let f = Sum; + Sum(p) + }", + &expect![[r#" + Callable Main: input=Tuple() + local: Bind(p: (Int, Int)) + Callable Sum: input=Tuple(Bind(p_0: Int), Bind(p_1: Int))"#]], + ); +} + +#[test] +fn unreachable_partial_application_does_not_block_promotion() { + check( + "struct Pair { X : Int, Y : Int } + operation UsePair(p : Pair, q : Qubit) : Unit { + let _ = p.X + p.Y; + } + operation Unused() : Unit { + use q = Qubit(); + let _f = UsePair(_, q); + } + @EntryPoint() + operation Main() : Unit { + use q = Qubit(); + UsePair(new Pair { X = 1, Y = 2 }, q); + }", + &expect![[r#" + Callable Main: input=Tuple() + local: Bind(q: Qubit) + Callable UsePair: input=Tuple(Tuple(Bind(p_0: Int), Bind(p_1: Int)), Bind(q: Qubit)) + local: Discard(Int)"#]], + ); +} + +#[test] +fn unreachable_first_class_reference_does_not_block_promotion() { + check( + "struct Pair { X : Int, Y : Int } + operation UsePair(p : Pair, q : Qubit) : Unit { + let _ = p.X + p.Y; + } + operation UnusedRef() : Unit { + let f = UsePair; + } + @EntryPoint() + operation Main() : Unit { + use q = Qubit(); + UsePair(new Pair { X = 1, Y = 2 }, q); + }", + &expect![[r#" + Callable Main: input=Tuple() + local: Bind(q: Qubit) + Callable UsePair: input=Tuple(Tuple(Bind(p_0: Int), Bind(p_1: Int)), Bind(q: Qubit)) + local: Discard(Int)"#]], + ); +} + +#[test] +fn controlled_specialization_params_promoted() { + // Operation with Ctl + CtlAdj spec: controlled body should also + // have its parameters promoted when field-only access is used. + check( + "struct Pair { X : Int, Y : Int } + operation Foo(p : Pair) : Unit is Ctl + Adj { + body ... { + let _ = p.X + p.Y; + } + adjoint self; + controlled (cs, ...) 
{ + let _ = p.X + p.Y; + } + controlled adjoint self; + } + @EntryPoint() + operation Main() : Unit { + use q = Qubit(); + Controlled Foo([q], new Pair { X = 3, Y = 4 }); + }", + &expect![[r#" + Callable Foo: input=Tuple(Bind(p_0: Int), Bind(p_1: Int)) + local: Discard(Int) + Callable Main: input=Tuple() + local: Bind(q: Qubit)"#]], + ); +} + +#[test] +fn direct_callable_alias_does_not_block_promotion() { + // A used direct callable alias is rewritten back to the callee before + // arg_promote runs, so the alias itself does not keep the callable from + // having its tuple parameter promoted. + check( + "struct Pair { X : Int, Y : Int } + function UsePair(p : Pair) : Int { + p.X + p.Y + } + function Main() : Int { + let f = UsePair; + f(new Pair { X = 3, Y = 4 }) + }", + &expect![[r#" + Callable Main: input=Tuple() + Callable UsePair: input=Tuple(Bind(p_0: Int), Bind(p_1: Int))"#]], + ); +} + +#[test] +fn promoted_call_sites_keep_targeted_arguments_in_source_order() { + let source = "struct Pair { X : Int, Y : Int } + function Promoted(p : Pair) : Int { + p.X + p.Y + } + function KeepWhole(p : Pair) : Pair { + p + } + function Main() : Int { + let left = new Pair { X = 1, Y = 2 }; + let middle = new Pair { X = 3, Y = 4 }; + let right = KeepWhole(left); + Promoted(middle) + Promoted(right) + }"; + + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ArgPromote); + + let result = extract_call_shapes(&store, pkg_id, "Main"); + expect![[r#" + KeepWhole(left) + Promoted((middle.0, middle.1)) + Promoted((right.0, right.1))"#]] + .assert_eq(&result); +} + +#[test] +fn aggregate_argument_expression_is_bound_once_before_field_projection() { + let source = "struct Pair { X : Int, Y : Int } + function BuildPair() : Pair { + new Pair { X = 1, Y = 2 } + } + function Sum(p : Pair) : Int { + p.X + p.Y + } + function Main() : Int { + Sum(BuildPair()) + }"; + + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ArgPromote); + let package = 
store.get(pkg_id); + + let main_block = package.get_block(callable_body_block_id(package, "Main")); + assert_eq!( + main_block.stmts.len(), + 1, + "expected Main to contain one rewritten expression" + ); + + let outer_stmt = package.get_stmt(main_block.stmts[0]); + let StmtKind::Expr(block_expr_id) = &outer_stmt.kind else { + panic!("expected Main body to end with an expression statement"); + }; + + let block_expr = package.get_expr(*block_expr_id); + let ExprKind::Block(rewritten_block_id) = block_expr.kind else { + panic!("expected promoted call to be wrapped in a block"); + }; + + let rewritten_block = package.get_block(rewritten_block_id); + assert_eq!( + rewritten_block.stmts.len(), + 2, + "expected rewritten block to bind the aggregate once and then call Sum" + ); + + let bind_stmt = package.get_stmt(rewritten_block.stmts[0]); + let StmtKind::Local(Mutability::Immutable, temp_pat_id, init_expr_id) = &bind_stmt.kind else { + panic!("expected first rewritten block statement to bind the aggregate argument"); + }; + + let temp_pat = package.get_pat(*temp_pat_id); + let PatKind::Bind(temp_ident) = &temp_pat.kind else { + panic!("expected synthesized binding pattern for aggregate argument"); + }; + expect_direct_item_call(package, *init_expr_id, "BuildPair"); + + let call_stmt = package.get_stmt(rewritten_block.stmts[1]); + let StmtKind::Expr(sum_call_id) = &call_stmt.kind else { + panic!("expected second rewritten block statement to be the promoted call"); + }; + + let promoted_arg_id = expect_direct_item_call(package, *sum_call_id, "Sum"); + let promoted_arg = package.get_expr(promoted_arg_id); + let ExprKind::Tuple(field_expr_ids) = &promoted_arg.kind else { + panic!("expected promoted call argument to be rebuilt as a tuple"); + }; + assert_eq!(field_expr_ids.len(), 2, "expected two projected fields"); + + for (index, field_expr_id) in field_expr_ids.iter().enumerate() { + let field_expr = package.get_expr(*field_expr_id); + let ExprKind::Field(base_expr_id, 
Field::Path(path)) = &field_expr.kind else { + panic!("expected promoted tuple element to be a field projection"); + }; + let base_expr = package.get_expr(*base_expr_id); + let ExprKind::Var(Res::Local(local_id), _) = &base_expr.kind else { + panic!("expected promoted field projection to read from the synthesized binding"); + }; + assert_eq!(*local_id, temp_ident.id); + assert_eq!(path.indices, vec![index]); + } + + let call_shapes = extract_call_shapes(&store, pkg_id, "Main"); + assert_eq!( + call_shapes + .lines() + .filter(|line| line.starts_with("BuildPair(")) + .count(), + 1, + "expected BuildPair to be evaluated once after promotion:\n{call_shapes}" + ); +} + +#[test] +fn simulatable_intrinsic_tuple_parameter_is_promoted() { + let source = "struct Pair { X : Int, Y : Int } + @SimulatableIntrinsic() + operation MeasurePair(p : Pair) : Int { + p.X + p.Y + } + @EntryPoint() + operation Main() : Int { + let pair = new Pair { X = 1, Y = 2 }; + MeasurePair(pair) + }"; + + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ArgPromote); + + expect![[r#" + Callable Main: input=Tuple() + local: Bind(pair: (Int, Int)) + Callable MeasurePair: input=Tuple(Bind(p_0: Int), Bind(p_1: Int))"#]] + .assert_eq(&extract_result(&store, pkg_id)); + + expect![[r#" + MeasurePair((pair.0, pair.1))"#]] + .assert_eq(&extract_call_shapes(&store, pkg_id, "Main")); +} + +#[test] +fn shared_nested_field_aliases_are_rewritten_with_fresh_inner_nodes() { + let source = "struct Inner { A : Int, B : Int } + struct Outer { Left : Inner, Extra : Int } + function Sum(o : Outer) : Int { + o.Left.A + o.Extra + } + function Main() : Int { + Sum(new Outer { Left = new Inner { A = 1, B = 2 }, Extra = 3 }) + }"; + + let (mut store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Sroa); + force_shared_nested_field_inner_expr(&mut store, pkg_id, "Sum", "o"); + + let mut assigner = Assigner::from_package(store.get(pkg_id)); + arg_promote(&mut store, pkg_id, &mut assigner); 
+ + let result = extract_field_access_shapes(&store, pkg_id, "Sum"); + assert!( + result.contains("o_0_0.0"), + "expected rewritten field access to target the decomposed inner binding:\n{result}" + ); + assert!( + !result.contains(".0.1"), + "shared ExprId rewrite left a poisoned nested field path:\n{result}" + ); +} + +#[test] +fn closure_targets_are_excluded_from_promotion() { + let source = "struct Pair { X : Int, Y : Int } + function Main() : Int { + let chooser: Pair -> Int = pair -> pair.X + pair.Y; + chooser(new Pair { X = 1, Y = 2 }) + }"; + + let (mut store, pkg_id) = compile_to_fir(source); + let package = store.get(pkg_id); + let reachable = crate::reachability::collect_reachable_from_entry(&store, pkg_id); + let closure_targets = super::collect_closure_targets(package, pkg_id, &reachable); + let mut closure_target_names = closure_targets + .iter() + .map(|item_id| { + let item = package.get_item(*item_id); + let ItemKind::Callable(decl) = &item.kind else { + panic!("closure target should be callable"); + }; + decl.name.name.to_string() + }) + .collect::>(); + closure_target_names.sort(); + assert_eq!(closure_target_names, vec!["".to_string()]); + + let mut assigner = Assigner::from_package(store.get(pkg_id)); + arg_promote(&mut store, pkg_id, &mut assigner); + + let package = store.get(pkg_id); + let lambda = find_callable(package, ""); + let mut binding_names = Vec::new(); + collect_pat_binding_names(package, lambda.input, &mut binding_names); + binding_names.sort(); + assert_eq!(binding_names, vec!["pair".to_string()]); +} + +#[test] +fn arg_promote_is_idempotent() { + let source = "struct Pair { X : Int, Y : Int } + function Foo(p : Pair) : Int { p.X + p.Y } + function Main() : Int { Foo(new Pair { X = 1, Y = 2 }) }"; + let (mut store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ArgPromote); + let first = crate::pretty::write_package_qsharp(&store, pkg_id); + let mut assigner = Assigner::from_package(store.get(pkg_id)); + 
arg_promote(&mut store, pkg_id, &mut assigner); + let second = crate::pretty::write_package_qsharp(&store, pkg_id); + assert_eq!(first, second, "arg_promote should be idempotent"); +} + +#[test] +fn arg_promote_preserves_invariants() { + let source = "struct Pair { X : Int, Y : Int } + function Foo(p : Pair) : Int { p.X + p.Y } + function Main() : Int { Foo(new Pair { X = 1, Y = 2 }) }"; + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ArgPromote); + crate::invariants::check( + &store, + pkg_id, + crate::invariants::InvariantLevel::PostArgPromote, + ); +} + +fn render_before_after_arg_promote(source: &str) -> (String, String) { + let (mut store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Sroa); + let before = crate::pretty::write_package_qsharp(&store, pkg_id); + let mut assigner = Assigner::from_package(store.get(pkg_id)); + arg_promote(&mut store, pkg_id, &mut assigner); + let after = crate::pretty::write_package_qsharp(&store, pkg_id); + (before, after) +} + +fn check_before_after(source: &str, expect: &Expect) { + let (before, after) = render_before_after_arg_promote(source); + expect.assert_eq(&format!("BEFORE:\n{before}\nAFTER:\n{after}")); +} + +#[test] +fn before_after_param_decomposition() { + check_before_after( + "struct Pair { X : Int, Y : Int } + function Foo(p : Pair) : Int { p.X + p.Y } + function Main() : Int { Foo(new Pair { X = 1, Y = 2 }) }", + &expect![[r#" + BEFORE: + // namespace test + newtype Pair = (Int, Int); + function Foo(p : (Int, Int)) : Int { + body { + p::Item < 0 > + p::Item < 1 > + } + } + function Main() : Int { + body { + Foo(1, 2) + } + } + // entry + Main() + + AFTER: + // namespace test + newtype Pair = (Int, Int); + function Foo(p_0 : Int, p_1 : Int) : Int { + body { + p_0 + p_1 + } + } + function Main() : Int { + body { + Foo(1, 2) + } + } + // entry + Main() + "#]], // snapshot populated by UPDATE_EXPECT=1 + ); +} + +#[test] +fn pretty_print_after_arg_promote_is_non_empty() { + 
let source = indoc! {r#" + namespace Test { + function Add(pair : (Int, Int)) : Int { + let (a, b) = pair; + a + b + } + + @EntryPoint() + function Main() : Int { + Add((3, 4)) + } + } + "#}; + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ArgPromote); + let rendered = crate::pretty::write_package_qsharp(&store, pkg_id); + // After arg_promote the rendered Q# uses flattened parameters and + // `body { ... }` spec syntax which is not valid Q# surface syntax. + // Verify the render at least produces non-empty output. + assert!( + !rendered.is_empty(), + "pretty-printed Q# after arg_promote should not be empty" + ); +} + +#[test] +fn unreachable_caller_call_site_behavior() { + // Dead callable calls a promoted target — document whether it gets rewritten. + // This captures current (package-wide) behavior before scope narrowing. + check( + indoc! {" + namespace Test { + @EntryPoint() + operation Main() : Int { + Foo((1, 2)) + } + operation Foo(x : (Int, Int)) : Int { + let (a, b) = x; + a + b + } + // Dead callable — never called from entry path + operation Dead() : Int { + Foo((3, 4)) + } + } + "}, + &expect![[r#" + Callable Foo: input=Bind(x: (Int, Int)) + local: Tuple(Bind(a: Int), Bind(b: Int)) + Callable Main: input=Tuple()"#]], + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/cloner.rs b/source/compiler/qsc_fir_transforms/src/cloner.rs new file mode 100644 index 0000000000..e79519ecb8 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/cloner.rs @@ -0,0 +1,754 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Deep-clone + ID-remap infrastructure for FIR subtrees. +//! +//! [`FirCloner`] copies blocks, expressions, patterns, and statements from a +//! source package into a target package while assigning fresh IDs to every +//! cloned node. All internal references (sub-expression IDs, block IDs, pattern +//! IDs, etc.) 
are remapped so the cloned subtree is self-consistent and does +//! not collide with existing IDs in the target package. + +#[cfg(test)] +mod tests; + +use qsc_fir::{ + assigner::Assigner, + fir::{ + Block, BlockId, CallableDecl, CallableImpl, ExecGraph, ExecGraphDebugNode, ExecGraphNode, + Expr, ExprId, ExprKind, FieldAssign, Ident, Item, ItemId, ItemKind, LocalItemId, + LocalVarId, NodeId, Package, Pat, PatId, PatKind, Res, SpecDecl, SpecImpl, Stmt, StmtId, + StmtKind, StringComponent, + }, +}; +use rustc_hash::FxHashMap; +use std::rc::Rc; + +/// Deep-clones FIR subtrees with full ID remapping. +/// +/// All package-global IDs (`BlockId`, `ExprId`, `PatId`, `StmtId`, `NodeId`) +/// are replaced with fresh values allocated from the internal `Assigner`. +/// `LocalVarId`s are remapped per-clone to avoid collisions when the cloned +/// body is placed into a different callable scope. +pub struct FirCloner { + /// Assigner for allocating fresh IDs above the target package's maximum. + assigner: Assigner, + /// Old → new remap tables. + block_map: FxHashMap, + expr_map: FxHashMap, + pat_map: FxHashMap, + stmt_map: FxHashMap, + local_map: FxHashMap, + /// Reserved for future use. `NodeId` remapping is currently a no-op + /// delegated to [`Assigner::next_node`]; the field is retained so lookups + /// from `Old` → `New` can be added without changing the public surface. + node_map: FxHashMap, + /// Old → new remap for nested items (`StmtKind::Item` / `ExprKind::Closure`). + item_map: FxHashMap, + /// Per-clone local variable counter. + next_local: u32, + /// Optional remap for self-referencing recursive callables. + /// When set, `Res::Item(old)` matching the first element is remapped to + /// `Res::Item(new)` with the second element. + self_item_remap: Option<(ItemId, ItemId)>, +} + +impl FirCloner { + /// Creates a new cloner whose counters start above the maximum existing IDs + /// in `package`. 
+ #[must_use] + pub fn new(package: &Package) -> Self { + let assigner = Assigner::from_package(package); + Self { + assigner, + block_map: FxHashMap::default(), + expr_map: FxHashMap::default(), + pat_map: FxHashMap::default(), + stmt_map: FxHashMap::default(), + local_map: FxHashMap::default(), + node_map: FxHashMap::default(), + item_map: FxHashMap::default(), + next_local: 0, + self_item_remap: None, + } + } + + /// Creates a new cloner initialized with the provided `Assigner`. + /// + /// Use this when an `Assigner` with correct watermarks is already + /// available (e.g., captured from the lowerer), avoiding the O(n) + /// scan performed by [`FirCloner::new`]. + #[must_use] + pub fn from_assigner(assigner: Assigner) -> Self { + Self { + assigner, + block_map: FxHashMap::default(), + expr_map: FxHashMap::default(), + pat_map: FxHashMap::default(), + stmt_map: FxHashMap::default(), + local_map: FxHashMap::default(), + node_map: FxHashMap::default(), + item_map: FxHashMap::default(), + next_local: 0, + self_item_remap: None, + } + } + + /// Creates a cloner whose `LocalVarId` counter starts at `local_offset`. + /// + /// Use this when inlining a callee body into a caller: set `local_offset` + /// to one past the caller's maximum `LocalVarId` so the inlined locals do + /// not shadow the caller's variables. + #[must_use] + pub fn with_local_offset(package: &Package, local_offset: LocalVarId) -> Self { + let assigner = Assigner::from_package(package); + Self { + assigner, + block_map: FxHashMap::default(), + expr_map: FxHashMap::default(), + pat_map: FxHashMap::default(), + stmt_map: FxHashMap::default(), + local_map: FxHashMap::default(), + node_map: FxHashMap::default(), + item_map: FxHashMap::default(), + next_local: local_offset.into(), + self_item_remap: None, + } + } + + /// Sets the self-item remap so that `Res::Item(old)` references are + /// rewritten to `Res::Item(new)`. 
Used when cloning a recursive callable + /// to point self-calls at the newly created specialization. + pub fn set_self_item_remap(&mut self, old: ItemId, new: ItemId) { + self.self_item_remap = Some((old, new)); + } + + /// Resets the per-clone remap tables and the local counter. + /// + /// Call this between successive clone operations to start a fresh mapping + /// (e.g., when cloning multiple callables with the same `FirCloner`). + pub fn reset_maps(&mut self) { + self.block_map.clear(); + self.expr_map.clear(); + self.pat_map.clear(); + self.stmt_map.clear(); + self.local_map.clear(); + self.node_map.clear(); + self.item_map.clear(); + self.next_local = 0; + self.self_item_remap = None; + } + + /// Clones all specializations of a `CallableImpl`, inserting cloned nodes + /// into `target`. + pub fn clone_callable_impl( + &mut self, + source: &Package, + callable_impl: &CallableImpl, + target: &mut Package, + ) -> CallableImpl { + match callable_impl { + CallableImpl::Intrinsic => CallableImpl::Intrinsic, + CallableImpl::Spec(spec_impl) => { + CallableImpl::Spec(self.clone_spec_impl(source, spec_impl, target)) + } + CallableImpl::SimulatableIntrinsic(spec_decl) => { + CallableImpl::SimulatableIntrinsic(self.clone_spec_decl(source, spec_decl, target)) + } + } + } + + /// Clones a `SpecImpl` (body + optional adj / ctl / ctl-adj specializations). + pub fn clone_spec_impl( + &mut self, + source: &Package, + spec_impl: &SpecImpl, + target: &mut Package, + ) -> SpecImpl { + let body = self.clone_spec_decl(source, &spec_impl.body, target); + let adj = spec_impl + .adj + .as_ref() + .map(|s| self.clone_spec_decl(source, s, target)); + let ctl = spec_impl + .ctl + .as_ref() + .map(|s| self.clone_spec_decl(source, s, target)); + let ctl_adj = spec_impl + .ctl_adj + .as_ref() + .map(|s| self.clone_spec_decl(source, s, target)); + SpecImpl { + body, + adj, + ctl, + ctl_adj, + } + } + + /// Clones a single `SpecDecl` (one specialization body) into `target`. 
+ pub fn clone_spec_decl( + &mut self, + source: &Package, + spec: &SpecDecl, + target: &mut Package, + ) -> SpecDecl { + let new_node = self.alloc_node(spec.id); + // Clone input BEFORE block so that `local_map` contains input + // parameter mappings when body expressions are walked. + let new_input = spec + .input + .map(|pat_id| self.clone_pat(source, pat_id, target)); + let new_block = self.clone_block(source, spec.block, target); + let new_exec_graph = self.remap_exec_graph(&spec.exec_graph); + SpecDecl { + id: new_node, + span: spec.span, + block: new_block, + input: new_input, + exec_graph: new_exec_graph, + } + } + + /// Clones a block and all its transitive children into `target`. + pub fn clone_block( + &mut self, + source: &Package, + block_id: BlockId, + target: &mut Package, + ) -> BlockId { + if let Some(&mapped) = self.block_map.get(&block_id) { + return mapped; + } + let new_id = self.assigner.next_block(); + self.block_map.insert(block_id, new_id); + + let block = source + .blocks + .get(block_id) + .expect("block should exist in source package"); + let new_stmts: Vec = block + .stmts + .iter() + .map(|&stmt_id| self.clone_stmt(source, stmt_id, target)) + .collect(); + let new_block = Block { + id: new_id, + span: block.span, + ty: block.ty.clone(), + stmts: new_stmts, + }; + target.blocks.insert(new_id, new_block); + new_id + } + + /// Clones a statement into `target`. 
+ pub fn clone_stmt( + &mut self, + source: &Package, + stmt_id: StmtId, + target: &mut Package, + ) -> StmtId { + if let Some(&mapped) = self.stmt_map.get(&stmt_id) { + return mapped; + } + let new_id = self.assigner.next_stmt(); + self.stmt_map.insert(stmt_id, new_id); + + let stmt = source + .stmts + .get(stmt_id) + .expect("stmt should exist in source package"); + let new_kind = match &stmt.kind { + StmtKind::Expr(expr_id) => StmtKind::Expr(self.clone_expr(source, *expr_id, target)), + StmtKind::Semi(expr_id) => StmtKind::Semi(self.clone_expr(source, *expr_id, target)), + StmtKind::Local(mutability, pat_id, expr_id) => StmtKind::Local( + *mutability, + self.clone_pat(source, *pat_id, target), + self.clone_expr(source, *expr_id, target), + ), + StmtKind::Item(item_id) => { + let new_item_id = self.clone_nested_item(source, *item_id, target); + StmtKind::Item(new_item_id) + } + }; + let new_stmt = Stmt { + id: new_id, + span: stmt.span, + kind: new_kind, + exec_graph_range: stmt.exec_graph_range.clone(), + }; + target.stmts.insert(new_id, new_stmt); + new_id + } + + /// Clones a nested item (e.g., from `StmtKind::Item` or `ExprKind::Closure`) + /// into `target`, allocating a fresh `LocalItemId` and remapping its body. + /// + /// Returns the new `LocalItemId` in the target package. + pub fn clone_nested_item( + &mut self, + source: &Package, + item_id: LocalItemId, + target: &mut Package, + ) -> LocalItemId { + // Return existing mapping if already cloned. + if let Some(&mapped) = self.item_map.get(&item_id) { + return mapped; + } + + let new_id = self.alloc_item(); + self.item_map.insert(item_id, new_id); + + let item = source + .items + .get(item_id) + .expect("item should exist in source package"); + + let new_kind = match &item.kind { + ItemKind::Callable(decl) => { + // Save the outer scope's local_map and counter so that the + // nested item's parameters don't overwrite them. 
LocalVarIds + // are scoped per-callable and commonly reuse the same values + // across different scopes. + let saved_local_map = self.local_map.clone(); + let saved_next_local = self.next_local; + self.local_map = FxHashMap::default(); + self.next_local = 0; + + let new_input = self.clone_pat(source, decl.input, target); + let new_impl = self.clone_callable_impl(source, &decl.implementation, target); + + // Restore the outer scope's local_map and counter. + self.local_map = saved_local_map; + self.next_local = saved_next_local; + + let new_node = self.alloc_node(decl.id); + ItemKind::Callable(Box::new(CallableDecl { + id: new_node, + span: decl.span, + kind: decl.kind, + name: Ident { + id: LocalVarId::default(), + span: decl.name.span, + name: Rc::clone(&decl.name.name), + }, + generics: decl.generics.clone(), + input: new_input, + output: decl.output.clone(), + functors: decl.functors, + implementation: new_impl, + attrs: decl.attrs.clone(), + })) + } + ItemKind::Namespace(ident, items) => ItemKind::Namespace(ident.clone(), items.clone()), + ItemKind::Ty(ident, udt) => ItemKind::Ty(ident.clone(), udt.clone()), + ItemKind::Export(ident, res) => ItemKind::Export(ident.clone(), *res), + }; + + let new_item = Item { + id: new_id, + span: item.span, + parent: item.parent, + doc: Rc::clone(&item.doc), + attrs: item.attrs.clone(), + visibility: item.visibility, + kind: new_kind, + }; + target.items.insert(new_id, new_item); + new_id + } + + /// Clones an expression into `target`, remapping all sub-expression and + /// block references. 
+ pub fn clone_expr( + &mut self, + source: &Package, + expr_id: ExprId, + target: &mut Package, + ) -> ExprId { + if let Some(&mapped) = self.expr_map.get(&expr_id) { + return mapped; + } + let new_id = self.assigner.next_expr(); + self.expr_map.insert(expr_id, new_id); + + let expr = source + .exprs + .get(expr_id) + .expect("expr should exist in source package"); + let new_kind = self.clone_expr_kind(source, &expr.kind, target); + let new_expr = Expr { + id: new_id, + span: expr.span, + ty: expr.ty.clone(), + kind: new_kind, + exec_graph_range: expr.exec_graph_range.clone(), + }; + target.exprs.insert(new_id, new_expr); + new_id + } + + /// Clones a pattern into `target`, remapping `LocalVarId` in bindings. + pub fn clone_pat(&mut self, source: &Package, pat_id: PatId, target: &mut Package) -> PatId { + if let Some(&mapped) = self.pat_map.get(&pat_id) { + return mapped; + } + let new_id = self.assigner.next_pat(); + self.pat_map.insert(pat_id, new_id); + + let pat = source + .pats + .get(pat_id) + .expect("pat should exist in source package"); + let new_kind = match &pat.kind { + PatKind::Bind(ident) => { + let new_local = self.alloc_local(ident.id); + PatKind::Bind(Ident { + id: new_local, + span: ident.span, + name: Rc::clone(&ident.name), + }) + } + PatKind::Discard => PatKind::Discard, + PatKind::Tuple(pats) => { + let new_pats: Vec = pats + .iter() + .map(|&p| self.clone_pat(source, p, target)) + .collect(); + PatKind::Tuple(new_pats) + } + }; + let new_pat = Pat { + id: new_id, + span: pat.span, + ty: pat.ty.clone(), + kind: new_kind, + }; + target.pats.insert(new_id, new_pat); + new_id + } + + /// Clones the input pattern of a callable. This is a convenience that + /// delegates to [`clone_pat`](Self::clone_pat). + pub fn clone_input_pat( + &mut self, + source: &Package, + pat_id: PatId, + target: &mut Package, + ) -> PatId { + self.clone_pat(source, pat_id, target) + } + + /// Remaps a `Res` reference. 
+ /// + /// - `Res::Local(var)` → remapped local + /// - `Res::Item(id)` → remapped only when matching `self_item_remap` + /// - `Res::Err` → unchanged + /// + /// Item references inside [`ExprKind::Closure(_, id)`](ExprKind::Closure) + /// are not routed through this helper. `clone_expr_kind` remaps them + /// through a parallel path: first consulting `item_map`, then falling + /// back to [`clone_nested_item`](Self::clone_nested_item) when the + /// referenced item lives in the source package, and finally consulting + /// `self_item_remap` for the recursive self-item case. Both paths must + /// agree on the resulting `LocalItemId`. + #[must_use] + pub fn remap_res(&self, res: &Res) -> Res { + match res { + Res::Local(var) => Res::Local(*self.local_map.get(var).unwrap_or(var)), + Res::Item(item_id) => { + if let Some((old, new)) = &self.self_item_remap + && item_id == old + { + return Res::Item(*new); + } + Res::Item(*item_id) + } + Res::Err => Res::Err, + } + } + + /// Remaps all typed IDs embedded in an `ExecGraph`. + #[must_use] + pub fn remap_exec_graph(&self, graph: &ExecGraph) -> ExecGraph { + let remap_configured = |nodes: &[ExecGraphNode]| -> Rc<[ExecGraphNode]> { + nodes + .iter() + .map(|node| self.remap_exec_graph_node(*node)) + .collect::>() + .into() + }; + + // ExecGraph stores its fields as Rc<[ExecGraphNode]>. We need to + // extract, remap, and reconstruct. + let no_debug = remap_configured(graph.select_ref(qsc_fir::fir::ExecGraphConfig::NoDebug)); + let debug = remap_configured(graph.select_ref(qsc_fir::fir::ExecGraphConfig::Debug)); + ExecGraph::new(no_debug, debug) + } + + /// Returns a reference to the current block remap table. + #[must_use] + pub fn block_map(&self) -> &FxHashMap { + &self.block_map + } + + /// Returns a reference to the current expression remap table. + #[must_use] + pub fn expr_map(&self) -> &FxHashMap { + &self.expr_map + } + + /// Returns a reference to the current local variable remap table. 
+ #[must_use] + pub fn local_map(&self) -> &FxHashMap { + &self.local_map + } + + /// Returns a reference to the current pattern remap table. + #[must_use] + pub fn pat_map(&self) -> &FxHashMap { + &self.pat_map + } + + /// Returns a reference to the current item remap table. + #[must_use] + pub fn item_map(&self) -> &FxHashMap { + &self.item_map + } + + /// Allocates a fresh `ExprId`. + pub fn alloc_expr(&mut self) -> ExprId { + self.assigner.next_expr() + } + + /// Allocates a fresh `PatId`. + pub fn alloc_pat(&mut self) -> PatId { + self.assigner.next_pat() + } + + /// Allocates a fresh `LocalItemId`. + pub fn alloc_item(&mut self) -> LocalItemId { + self.assigner.next_item() + } + + /// Consumes the cloner and returns the internal `Assigner` with its + /// counters advanced past all IDs allocated during cloning. + #[must_use] + pub fn into_assigner(self) -> Assigner { + self.assigner + } + + fn alloc_node(&mut self, _old: NodeId) -> NodeId { + // `_old` is reserved for future use. Today every cloned node receives + // a fresh id with no lookup against `node_map`; the parameter is kept + // so a remap table can be wired in without changing call sites. + self.assigner.next_node() + } + + pub(crate) fn next_node(&mut self) -> NodeId { + self.assigner.next_node() + } + + pub(crate) fn alloc_local(&mut self, old: LocalVarId) -> LocalVarId { + let new = LocalVarId::from(self.next_local); + self.next_local += 1; + self.local_map.insert(old, new); + new + } + + /// Clones one expression kind into `target`, recursively remapping every + /// referenced child id. + /// + /// Before, `kind` points at blocks, expressions, and patterns owned by the + /// source package. After, the returned `ExprKind` has the same shape but all + /// referenced children have been cloned into `target` and replaced with the + /// freshly allocated ids from this cloner. 
+    #[allow(clippy::too_many_lines)]
+    fn clone_expr_kind(
+        &mut self,
+        source: &Package,
+        kind: &ExprKind,
+        target: &mut Package,
+    ) -> ExprKind {
+        match kind {
+            ExprKind::Array(exprs) => ExprKind::Array(
+                exprs
+                    .iter()
+                    .map(|&e| self.clone_expr(source, e, target))
+                    .collect(),
+            ),
+            ExprKind::ArrayLit(exprs) => ExprKind::ArrayLit(
+                exprs
+                    .iter()
+                    .map(|&e| self.clone_expr(source, e, target))
+                    .collect(),
+            ),
+            ExprKind::ArrayRepeat(val, size) => ExprKind::ArrayRepeat(
+                self.clone_expr(source, *val, target),
+                self.clone_expr(source, *size, target),
+            ),
+            ExprKind::Assign(lhs, rhs) => ExprKind::Assign(
+                self.clone_expr(source, *lhs, target),
+                self.clone_expr(source, *rhs, target),
+            ),
+            ExprKind::AssignOp(op, lhs, rhs) => ExprKind::AssignOp(
+                *op,
+                self.clone_expr(source, *lhs, target),
+                self.clone_expr(source, *rhs, target),
+            ),
+            ExprKind::AssignField(record, field, replace) => ExprKind::AssignField(
+                self.clone_expr(source, *record, target),
+                field.clone(),
+                self.clone_expr(source, *replace, target),
+            ),
+            ExprKind::AssignIndex(container, index, replace) => ExprKind::AssignIndex(
+                self.clone_expr(source, *container, target),
+                self.clone_expr(source, *index, target),
+                self.clone_expr(source, *replace, target),
+            ),
+            ExprKind::BinOp(op, lhs, rhs) => ExprKind::BinOp(
+                *op,
+                self.clone_expr(source, *lhs, target),
+                self.clone_expr(source, *rhs, target),
+            ),
+            ExprKind::Block(block_id) => {
+                ExprKind::Block(self.clone_block(source, *block_id, target))
+            }
+            ExprKind::Call(callee, arg) => ExprKind::Call(
+                self.clone_expr(source, *callee, target),
+                self.clone_expr(source, *arg, target),
+            ),
+            ExprKind::Closure(vars, local_item_id) => {
+                let new_vars: Vec<LocalVarId> = vars
+                    .iter()
+                    .map(|v| *self.local_map.get(v).unwrap_or(v))
+                    .collect();
+                let new_item_id = if let Some(&mapped) = self.item_map.get(local_item_id) {
+                    mapped
+                } else if source.items.contains_key(*local_item_id) {
+                    self.clone_nested_item(source, *local_item_id, target)
+                } else if let Some((old, new)) = &self.self_item_remap {
+                    if *local_item_id == old.item {
+                        new.item
+                    } else {
+                        *local_item_id
+                    }
+                } else {
+                    *local_item_id
+                };
+                ExprKind::Closure(new_vars, new_item_id)
+            }
+            ExprKind::Fail(expr) => ExprKind::Fail(self.clone_expr(source, *expr, target)),
+            ExprKind::Field(expr, field) => {
+                ExprKind::Field(self.clone_expr(source, *expr, target), field.clone())
+            }
+            ExprKind::Hole => ExprKind::Hole,
+            ExprKind::If(cond, body, otherwise) => ExprKind::If(
+                self.clone_expr(source, *cond, target),
+                self.clone_expr(source, *body, target),
+                otherwise.map(|e| self.clone_expr(source, e, target)),
+            ),
+            ExprKind::Index(array, index) => ExprKind::Index(
+                self.clone_expr(source, *array, target),
+                self.clone_expr(source, *index, target),
+            ),
+            ExprKind::Lit(lit) => ExprKind::Lit(lit.clone()),
+            ExprKind::Range(start, step, end) => ExprKind::Range(
+                start.map(|e| self.clone_expr(source, e, target)),
+                step.map(|e| self.clone_expr(source, e, target)),
+                end.map(|e| self.clone_expr(source, e, target)),
+            ),
+            ExprKind::Return(expr) => ExprKind::Return(self.clone_expr(source, *expr, target)),
+            ExprKind::Struct(res, copy, fields) => {
+                let new_res = self.remap_res(res);
+                let new_copy = copy.map(|e| self.clone_expr(source, e, target));
+                let new_fields: Vec<FieldAssign> = fields
+                    .iter()
+                    .map(|fa| FieldAssign {
+                        id: self.assigner.next_node(),
+                        span: fa.span,
+                        field: fa.field.clone(),
+                        value: self.clone_expr(source, fa.value, target),
+                    })
+                    .collect();
+                ExprKind::Struct(new_res, new_copy, new_fields)
+            }
+            ExprKind::String(components) => {
+                let new_components: Vec<StringComponent> = components
+                    .iter()
+                    .map(|c| match c {
+                        StringComponent::Expr(expr) => {
+                            StringComponent::Expr(self.clone_expr(source, *expr, target))
+                        }
+                        StringComponent::Lit(s) => StringComponent::Lit(Rc::clone(s)),
+                    })
+                    .collect();
+                ExprKind::String(new_components)
+            }
+            ExprKind::UpdateIndex(e1, e2, e3) => ExprKind::UpdateIndex(
+                self.clone_expr(source, *e1,
target), + self.clone_expr(source, *e2, target), + self.clone_expr(source, *e3, target), + ), + ExprKind::Tuple(exprs) => ExprKind::Tuple( + exprs + .iter() + .map(|&e| self.clone_expr(source, e, target)) + .collect(), + ), + ExprKind::UnOp(op, expr) => ExprKind::UnOp(*op, self.clone_expr(source, *expr, target)), + ExprKind::UpdateField(record, field, replace) => ExprKind::UpdateField( + self.clone_expr(source, *record, target), + field.clone(), + self.clone_expr(source, *replace, target), + ), + ExprKind::Var(res, generic_args) => { + ExprKind::Var(self.remap_res(res), generic_args.clone()) + } + ExprKind::While(cond, block) => ExprKind::While( + self.clone_expr(source, *cond, target), + self.clone_block(source, *block, target), + ), + } + } + + fn remap_exec_graph_node(&self, node: ExecGraphNode) -> ExecGraphNode { + match node { + ExecGraphNode::Bind(pat_id) => { + ExecGraphNode::Bind(*self.pat_map.get(&pat_id).unwrap_or(&pat_id)) + } + ExecGraphNode::Expr(expr_id) => { + ExecGraphNode::Expr(*self.expr_map.get(&expr_id).unwrap_or(&expr_id)) + } + // Jump targets are graph-relative indices, not IDs — preserve them. 
+            ExecGraphNode::Jump(_)
+            | ExecGraphNode::JumpIf(_)
+            | ExecGraphNode::JumpIfNot(_)
+            | ExecGraphNode::Store
+            | ExecGraphNode::Unit
+            | ExecGraphNode::Ret => node,
+            ExecGraphNode::Debug(debug_node) => {
+                ExecGraphNode::Debug(self.remap_debug_node(debug_node))
+            }
+        }
+    }
+
+    fn remap_debug_node(&self, node: ExecGraphDebugNode) -> ExecGraphDebugNode {
+        match node {
+            ExecGraphDebugNode::Stmt(stmt_id) => {
+                ExecGraphDebugNode::Stmt(*self.stmt_map.get(&stmt_id).unwrap_or(&stmt_id))
+            }
+            ExecGraphDebugNode::PushLoopScope(expr_id) => {
+                ExecGraphDebugNode::PushLoopScope(*self.expr_map.get(&expr_id).unwrap_or(&expr_id))
+            }
+            ExecGraphDebugNode::BlockEnd(block_id) => {
+                ExecGraphDebugNode::BlockEnd(*self.block_map.get(&block_id).unwrap_or(&block_id))
+            }
+            ExecGraphDebugNode::PushScope
+            | ExecGraphDebugNode::PopScope
+            | ExecGraphDebugNode::RetFrame
+            | ExecGraphDebugNode::LoopIteration => node,
+        }
+    }
+}
diff --git a/source/compiler/qsc_fir_transforms/src/cloner/tests.rs b/source/compiler/qsc_fir_transforms/src/cloner/tests.rs
new file mode 100644
index 0000000000..fb36853132
--- /dev/null
+++ b/source/compiler/qsc_fir_transforms/src/cloner/tests.rs
@@ -0,0 +1,775 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+use super::*;
+use qsc_data_structures::{index_map::IndexMap, span::Span};
+use qsc_fir::fir::{
+    Block, BlockId, CallableDecl, CallableImpl, CallableKind, ExecGraph, Expr, ExprId, ExprKind,
+    Item, LocalItemId, Mutability, NodeId, Pat, PatId, PatKind, SpecDecl, SpecImpl, Stmt, StmtId,
+    StmtKind, Visibility,
+};
+use qsc_fir::ty::{Arrow, FunctorSet, FunctorSetValue, Prim, Ty};
+use std::rc::Rc;
+
+fn empty_exec_graph_range() -> std::ops::Range<qsc_fir::fir::ExecGraphIdx> {
+    let zero = qsc_fir::fir::ExecGraphIdx {
+        no_debug_idx: 0,
+        debug_idx: 0,
+    };
+    zero..zero
+}
+
+/// Creates a minimal package with a single callable body for testing.
+#[allow(clippy::similar_names)]
+fn make_test_package() -> Package {
+    let mut blocks: IndexMap<BlockId, Block> = IndexMap::new();
+    let mut exprs: IndexMap<ExprId, Expr> = IndexMap::new();
+    let mut pats: IndexMap<PatId, Pat> = IndexMap::new();
+    let mut stmts: IndexMap<StmtId, Stmt> = IndexMap::new();
+
+    // Pat 0: Bind(x) with LocalVarId 0
+    let pat0 = Pat {
+        id: PatId::from(0u32),
+        span: Span::default(),
+        ty: Ty::Prim(qsc_fir::ty::Prim::Int),
+        kind: PatKind::Bind(Ident {
+            id: LocalVarId::from(0u32),
+            span: Span::default(),
+            name: "x".into(),
+        }),
+    };
+    pats.insert(PatId::from(0u32), pat0);
+
+    // Expr 0: Var(Local(0)) — reference to x
+    let expr0 = Expr {
+        id: ExprId::from(0u32),
+        span: Span::default(),
+        ty: Ty::Prim(qsc_fir::ty::Prim::Int),
+        kind: ExprKind::Var(Res::Local(LocalVarId::from(0u32)), vec![]),
+        exec_graph_range: empty_exec_graph_range(),
+    };
+    exprs.insert(ExprId::from(0u32), expr0);
+
+    // Expr 1: Lit(Int(42))
+    let expr1 = Expr {
+        id: ExprId::from(1u32),
+        span: Span::default(),
+        ty: Ty::Prim(qsc_fir::ty::Prim::Int),
+        kind: ExprKind::Lit(qsc_fir::fir::Lit::Int(42)),
+        exec_graph_range: empty_exec_graph_range(),
+    };
+    exprs.insert(ExprId::from(1u32), expr1);
+
+    // Stmt 0: Local(Immutable, Pat 0, Expr 1) — let x = 42;
+    let stmt0 = Stmt {
+        id: StmtId::from(0u32),
+        span: Span::default(),
+        kind: StmtKind::Local(Mutability::Immutable, PatId::from(0u32), ExprId::from(1u32)),
+        exec_graph_range: empty_exec_graph_range(),
+    };
+    stmts.insert(StmtId::from(0u32), stmt0);
+
+    // Stmt 1: Expr(Expr 0) — x (tail expression)
+    let stmt1 = Stmt {
+        id: StmtId::from(1u32),
+        span: Span::default(),
+        kind: StmtKind::Expr(ExprId::from(0u32)),
+        exec_graph_range: empty_exec_graph_range(),
+    };
+    stmts.insert(StmtId::from(1u32), stmt1);
+
+    // Block 0: [Stmt 0, Stmt 1]
+    let block0 = Block {
+        id: BlockId::from(0u32),
+        span: Span::default(),
+        ty: Ty::Prim(qsc_fir::ty::Prim::Int),
+        stmts: vec![StmtId::from(0u32), StmtId::from(1u32)],
+    };
+    blocks.insert(BlockId::from(0u32), block0);
+
+    Package
{ + items: IndexMap::new(), + entry: None, + entry_exec_graph: ExecGraph::default(), + blocks, + exprs, + pats, + stmts, + } +} + +#[test] +fn clone_block_produces_fresh_ids() { + let source = make_test_package(); + let mut target = make_test_package(); + let mut cloner = FirCloner::new(&target); + + let new_block_id = cloner.clone_block(&source, BlockId::from(0u32), &mut target); + + // New block ID must differ from original. + assert_ne!(u32::from(new_block_id), 0); + + // Target must contain the new block. + assert!(target.blocks.get(new_block_id).is_some()); + + // New block should have the same number of stmts. + let new_block = target.blocks.get(new_block_id).expect("block not found"); + assert_eq!(new_block.stmts.len(), 2); + + // All new stmt IDs should be > the original max (1). + for &stmt_id in &new_block.stmts { + assert!(u32::from(stmt_id) > 1); + } +} + +#[test] +fn clone_pat_remaps_local_var_id() { + let source = make_test_package(); + let mut target = make_test_package(); + // Use local_offset > 0 to simulate inlining into a caller that + // already uses locals 0..N. + let mut cloner = FirCloner::with_local_offset(&target, LocalVarId::from(10u32)); + + let new_pat_id = cloner.clone_pat(&source, PatId::from(0u32), &mut target); + let new_pat = target.pats.get(new_pat_id).expect("pat not found"); + + // The cloned pattern's Bind should have a fresh LocalVarId starting at 10. + if let PatKind::Bind(ident) = &new_pat.kind { + assert_eq!(ident.id, LocalVarId::from(10u32)); + } else { + panic!("expected PatKind::Bind"); + } +} + +#[test] +fn clone_pat_mono_local_starts_at_zero() { + let source = make_test_package(); + let mut target = make_test_package(); + let mut cloner = FirCloner::new(&target); + + let new_pat_id = cloner.clone_pat(&source, PatId::from(0u32), &mut target); + let new_pat = target.pats.get(new_pat_id).expect("pat not found"); + + // For monomorphization, locals start at 0 (new callable scope). 
+ if let PatKind::Bind(ident) = &new_pat.kind { + assert_eq!(ident.id, LocalVarId::from(0u32)); + // But the local_map should have recorded the mapping. + assert!(cloner.local_map().contains_key(&LocalVarId::from(0u32))); + } else { + panic!("expected PatKind::Bind"); + } +} + +#[test] +fn clone_expr_remaps_local_res() { + let source = make_test_package(); + let mut target = make_test_package(); + // Use offset to ensure locals are remapped to distinct values. + let mut cloner = FirCloner::with_local_offset(&target, LocalVarId::from(10u32)); + + // Clone the pat first so that the local mapping is established. + let _new_pat = cloner.clone_pat(&source, PatId::from(0u32), &mut target); + let new_expr_id = cloner.clone_expr(&source, ExprId::from(0u32), &mut target); + let new_expr = target.exprs.get(new_expr_id).expect("expr not found"); + + if let ExprKind::Var(Res::Local(var), _) = &new_expr.kind { + // The local ref should be remapped to the offset value. + assert_eq!(*var, LocalVarId::from(10u32)); + } else { + panic!("expected ExprKind::Var(Res::Local(_))"); + } +} + +#[test] +fn clone_preserves_cross_package_res() { + let target = make_test_package(); + let cloner = FirCloner::new(&target); + + // Manually insert an expr that references a cross-package item. 
+ let cross_pkg_item = ItemId { + package: qsc_fir::fir::PackageId::CORE, + item: LocalItemId::from(5usize), + }; + let cross_res = Res::Item(cross_pkg_item); + let remapped = cloner.remap_res(&cross_res); + assert_eq!(remapped, cross_res); +} + +#[test] +fn self_item_remap_rewrites_item_resource() { + let target = make_test_package(); + let mut cloner = FirCloner::new(&target); + + let old_item = ItemId { + package: qsc_fir::fir::PackageId::from(2usize), + item: LocalItemId::from(10usize), + }; + let new_item = ItemId { + package: qsc_fir::fir::PackageId::from(2usize), + item: LocalItemId::from(20usize), + }; + cloner.set_self_item_remap(old_item, new_item); + + let remapped = cloner.remap_res(&Res::Item(old_item)); + assert_eq!(remapped, Res::Item(new_item)); + + // Other items should not be affected. + let other_item = ItemId { + package: qsc_fir::fir::PackageId::from(2usize), + item: LocalItemId::from(11usize), + }; + let remapped_other = cloner.remap_res(&Res::Item(other_item)); + assert_eq!(remapped_other, Res::Item(other_item)); +} + +#[test] +fn clone_closure_with_captures_remaps_local_ids() { + let mut source = make_test_package(); + let mut target = make_test_package(); + + // Add a second local binding: Pat 1: Bind(y) with LocalVarId 1 + let pat1 = Pat { + id: PatId::from(1u32), + span: Span::default(), + ty: Ty::Prim(Prim::Int), + kind: PatKind::Bind(Ident { + id: LocalVarId::from(1u32), + span: Span::default(), + name: "y".into(), + }), + }; + source.pats.insert(PatId::from(1u32), pat1.clone()); + target.pats.insert(PatId::from(1u32), pat1); + + // Expr 2: Lit(Int(10)) — initializer for y + let expr2 = Expr { + id: ExprId::from(2u32), + span: Span::default(), + ty: Ty::Prim(Prim::Int), + kind: ExprKind::Lit(qsc_fir::fir::Lit::Int(10)), + exec_graph_range: empty_exec_graph_range(), + }; + source.exprs.insert(ExprId::from(2u32), expr2.clone()); + target.exprs.insert(ExprId::from(2u32), expr2); + + // Stmt 2: Local(Immutable, Pat 1, Expr 2) — let y = 10; 
+ let stmt2 = Stmt { + id: StmtId::from(2u32), + span: Span::default(), + kind: StmtKind::Local(Mutability::Immutable, PatId::from(1u32), ExprId::from(2u32)), + exec_graph_range: empty_exec_graph_range(), + }; + source.stmts.insert(StmtId::from(2u32), stmt2.clone()); + target.stmts.insert(StmtId::from(2u32), stmt2); + + // Expr 3: Closure capturing [LocalVarId(0), LocalVarId(1)], targeting LocalItemId(0) + let expr3 = Expr { + id: ExprId::from(3u32), + span: Span::default(), + ty: Ty::Arrow(Box::new(Arrow { + kind: CallableKind::Function, + input: Box::new(Ty::Prim(Prim::Int)), + output: Box::new(Ty::Prim(Prim::Int)), + functors: FunctorSet::Value(FunctorSetValue::Empty), + })), + kind: ExprKind::Closure( + vec![LocalVarId::from(0u32), LocalVarId::from(1u32)], + LocalItemId::from(0usize), + ), + exec_graph_range: empty_exec_graph_range(), + }; + source.exprs.insert(ExprId::from(3u32), expr3.clone()); + target.exprs.insert(ExprId::from(3u32), expr3); + + // Use offset so locals are remapped to distinct values. + let mut cloner = FirCloner::with_local_offset(&target, LocalVarId::from(10u32)); + + // Clone the patterns first to establish the local mappings. + let _new_pat0 = cloner.clone_pat(&source, PatId::from(0u32), &mut target); + let _new_pat1 = cloner.clone_pat(&source, PatId::from(1u32), &mut target); + + // Clone the closure expression. + let new_expr_id = cloner.clone_expr(&source, ExprId::from(3u32), &mut target); + let new_expr = target.exprs.get(new_expr_id).expect("expr not found"); + + // Verify captures are remapped. + if let ExprKind::Closure(captures, _item_id) = &new_expr.kind { + assert_eq!(captures.len(), 2); + assert_eq!(captures[0], LocalVarId::from(10u32)); + assert_eq!(captures[1], LocalVarId::from(11u32)); + } else { + panic!("expected ExprKind::Closure"); + } + + // Verify the expression type is preserved as Arrow. 
+ assert!(matches!(&new_expr.ty, Ty::Arrow(_))); +} + +#[test] +#[allow(clippy::similar_names)] +#[allow(clippy::too_many_lines)] +fn clone_nested_item_isolates_local_scope() { + let mut source = make_test_package(); + let mut target = make_test_package(); + + // Build a nested callable item (inner function) with its own local binding. + // Inner function body: let z = 99; z + + // Pat 2: Bind(z) with LocalVarId 0 (same as outer — scoped per-callable) + let inner_pat = Pat { + id: PatId::from(2u32), + span: Span::default(), + ty: Ty::Prim(Prim::Int), + kind: PatKind::Bind(Ident { + id: LocalVarId::from(0u32), + span: Span::default(), + name: "z".into(), + }), + }; + source.pats.insert(PatId::from(2u32), inner_pat.clone()); + target.pats.insert(PatId::from(2u32), inner_pat); + + // Pat 3: Discard — inner function input pattern (no parameters) + let inner_input_pat = Pat { + id: PatId::from(3u32), + span: Span::default(), + ty: Ty::UNIT, + kind: PatKind::Discard, + }; + source + .pats + .insert(PatId::from(3u32), inner_input_pat.clone()); + target.pats.insert(PatId::from(3u32), inner_input_pat); + + // Expr 2: Lit(Int(99)) + let inner_init = Expr { + id: ExprId::from(2u32), + span: Span::default(), + ty: Ty::Prim(Prim::Int), + kind: ExprKind::Lit(qsc_fir::fir::Lit::Int(99)), + exec_graph_range: empty_exec_graph_range(), + }; + source.exprs.insert(ExprId::from(2u32), inner_init.clone()); + target.exprs.insert(ExprId::from(2u32), inner_init); + + // Expr 3: Var(Local(0)) — reference to z + let inner_var = Expr { + id: ExprId::from(3u32), + span: Span::default(), + ty: Ty::Prim(Prim::Int), + kind: ExprKind::Var(Res::Local(LocalVarId::from(0u32)), vec![]), + exec_graph_range: empty_exec_graph_range(), + }; + source.exprs.insert(ExprId::from(3u32), inner_var.clone()); + target.exprs.insert(ExprId::from(3u32), inner_var); + + // Stmt 2: Local(Immutable, Pat 2, Expr 2) + let inner_let = Stmt { + id: StmtId::from(2u32), + span: Span::default(), + kind: 
StmtKind::Local(Mutability::Immutable, PatId::from(2u32), ExprId::from(2u32)), + exec_graph_range: empty_exec_graph_range(), + }; + source.stmts.insert(StmtId::from(2u32), inner_let.clone()); + target.stmts.insert(StmtId::from(2u32), inner_let); + + // Stmt 3: Expr(Expr 3) + let inner_tail = Stmt { + id: StmtId::from(3u32), + span: Span::default(), + kind: StmtKind::Expr(ExprId::from(3u32)), + exec_graph_range: empty_exec_graph_range(), + }; + source.stmts.insert(StmtId::from(3u32), inner_tail.clone()); + target.stmts.insert(StmtId::from(3u32), inner_tail); + + // Block 1: inner function body [Stmt 2, Stmt 3] + let inner_block = Block { + id: BlockId::from(1u32), + span: Span::default(), + ty: Ty::Prim(Prim::Int), + stmts: vec![StmtId::from(2u32), StmtId::from(3u32)], + }; + source + .blocks + .insert(BlockId::from(1u32), inner_block.clone()); + target.blocks.insert(BlockId::from(1u32), inner_block); + + // Item 0: Callable (inner function) + let inner_callable = Item { + id: LocalItemId::from(0usize), + span: Span::default(), + parent: None, + doc: Rc::from(""), + attrs: vec![], + visibility: Visibility::Public, + kind: ItemKind::Callable(Box::new(CallableDecl { + id: NodeId::from(0u32), + span: Span::default(), + kind: CallableKind::Function, + name: Ident { + id: LocalVarId::default(), + span: Span::default(), + name: "inner".into(), + }, + generics: vec![], + input: PatId::from(3u32), + output: Ty::Prim(Prim::Int), + functors: FunctorSetValue::Empty, + implementation: CallableImpl::Spec(SpecImpl { + body: SpecDecl { + id: NodeId::from(1u32), + span: Span::default(), + block: BlockId::from(1u32), + input: None, + exec_graph: ExecGraph::default(), + }, + adj: None, + ctl: None, + ctl_adj: None, + }), + attrs: vec![], + })), + }; + source + .items + .insert(LocalItemId::from(0usize), inner_callable.clone()); + target + .items + .insert(LocalItemId::from(0usize), inner_callable); + + // Add StmtKind::Item to the outer block so the nested item is reachable. 
+ let stmt_item = Stmt { + id: StmtId::from(4u32), + span: Span::default(), + kind: StmtKind::Item(LocalItemId::from(0usize)), + exec_graph_range: empty_exec_graph_range(), + }; + source.stmts.insert(StmtId::from(4u32), stmt_item.clone()); + target.stmts.insert(StmtId::from(4u32), stmt_item); + + // Add Stmt 4 to Block 0 (the outer block) + source + .blocks + .get_mut(BlockId::from(0u32)) + .expect("block 0") + .stmts + .push(StmtId::from(4u32)); + target + .blocks + .get_mut(BlockId::from(0u32)) + .expect("block 0") + .stmts + .push(StmtId::from(4u32)); + + // Clone with an offset so outer locals are remapped to 10+. + let mut cloner = FirCloner::with_local_offset(&target, LocalVarId::from(10u32)); + + // Clone the outer block (which includes the nested item via StmtKind::Item). + let _outer_pat = cloner.clone_pat(&source, PatId::from(0u32), &mut target); + let new_block_id = cloner.clone_block(&source, BlockId::from(0u32), &mut target); + + // Verify the outer local was remapped to 10. + let outer_new_local = cloner.local_map().get(&LocalVarId::from(0u32)).copied(); + assert_eq!( + outer_new_local, + Some(LocalVarId::from(10u32)), + "outer local should be remapped to offset 10" + ); + + // Verify that the nested item was cloned (should appear in item_map). + assert!( + !cloner.item_map().is_empty(), + "nested item should have been cloned" + ); + + // Verify the cloned nested item's inner local bindings are independent: + // The inner callable resets its locals to 0, so the inner Pat Bind(z) + // should have LocalVarId(0), not 10+. 
+ let new_inner_item_id = cloner + .item_map() + .get(&LocalItemId::from(0usize)) + .expect("inner item should be in item_map"); + let new_inner_item = target + .items + .get(*new_inner_item_id) + .expect("cloned inner item should exist"); + if let ItemKind::Callable(decl) = &new_inner_item.kind { + if let CallableImpl::Spec(spec_impl) = &decl.implementation { + let inner_block = target + .blocks + .get(spec_impl.body.block) + .expect("inner block"); + // The first statement is the local binding with the inner Pat. + let first_stmt = target.stmts.get(inner_block.stmts[0]).expect("inner stmt"); + if let StmtKind::Local(_, pat_id, _) = &first_stmt.kind { + let inner_pat = target.pats.get(*pat_id).expect("inner pat"); + if let PatKind::Bind(ident) = &inner_pat.kind { + // Inner callable's locals start fresh at 0. + assert_eq!( + ident.id, + LocalVarId::from(0u32), + "inner callable's local should start at 0, not inherit outer offset" + ); + } else { + panic!("expected PatKind::Bind on inner local"); + } + } else { + panic!("expected StmtKind::Local as first inner stmt"); + } + } else { + panic!("expected CallableImpl::Spec"); + } + } else { + panic!("expected ItemKind::Callable"); + } + + // Verify the outer block's cloned stmts exist. + let new_block = target.blocks.get(new_block_id).expect("new outer block"); + assert_eq!(new_block.stmts.len(), 3, "outer block should have 3 stmts"); +} + +/// Creates an empty package for use as a clone target. +fn empty_package() -> Package { + Package { + items: IndexMap::new(), + entry: None, + entry_exec_graph: ExecGraph::default(), + blocks: IndexMap::new(), + exprs: IndexMap::new(), + pats: IndexMap::new(), + stmts: IndexMap::new(), + } +} + +// ── Idempotency tests ── + +#[test] +fn clone_block_is_idempotent() { + let source = make_test_package(); + + // First clone: source → target1. 
+ let mut target1 = empty_package(); + let mut cloner1 = FirCloner::new(&target1); + let block1_id = cloner1.clone_block(&source, BlockId::from(0u32), &mut target1); + + // Second clone: target1 → target2. + let mut target2 = empty_package(); + let mut cloner2 = FirCloner::new(&target2); + let block2_id = cloner2.clone_block(&target1, block1_id, &mut target2); + + let block1 = target1.blocks.get(block1_id).expect("block1"); + let block2 = target2.blocks.get(block2_id).expect("block2"); + + assert_eq!(block1.stmts.len(), block2.stmts.len()); + assert_eq!(block1.ty, block2.ty); + + // Statement kind discriminants must match. + for (&s1, &s2) in block1.stmts.iter().zip(block2.stmts.iter()) { + let stmt1 = target1.stmts.get(s1).expect("stmt1"); + let stmt2 = target2.stmts.get(s2).expect("stmt2"); + assert_eq!( + std::mem::discriminant(&stmt1.kind), + std::mem::discriminant(&stmt2.kind), + ); + } + + // Element counts must match across both clones. + assert_eq!(target1.exprs.iter().count(), target2.exprs.iter().count()); + assert_eq!(target1.pats.iter().count(), target2.pats.iter().count()); + assert_eq!(target1.stmts.iter().count(), target2.stmts.iter().count()); +} + +#[test] +fn clone_expr_is_idempotent_for_literal() { + let source = make_test_package(); + + // First clone of Expr 1: Lit(Int(42)). + let mut target1 = empty_package(); + let mut cloner1 = FirCloner::new(&target1); + let expr1_id = cloner1.clone_expr(&source, ExprId::from(1u32), &mut target1); + + // Second clone from target1. 
+    let mut target2 = empty_package();
+    let mut cloner2 = FirCloner::new(&target2);
+    let expr2_id = cloner2.clone_expr(&target1, expr1_id, &mut target2);
+
+    let expr1 = target1.exprs.get(expr1_id).expect("expr1");
+    let expr2 = target2.exprs.get(expr2_id).expect("expr2");
+
+    assert_eq!(expr1.ty, expr2.ty);
+    match (&expr1.kind, &expr2.kind) {
+        (ExprKind::Lit(qsc_fir::fir::Lit::Int(v1)), ExprKind::Lit(qsc_fir::fir::Lit::Int(v2))) => {
+            assert_eq!(v1, v2, "literal value must survive double-clone");
+        }
+        _ => panic!("expected Lit(Int) on both clones"),
+    }
+}
+
+// ── Type preservation and structural assertion tests ──
+
+#[test]
+fn clone_preserves_expression_types() {
+    let source = make_test_package();
+    let mut target = empty_package();
+    let mut cloner = FirCloner::new(&target);
+    cloner.clone_block(&source, BlockId::from(0u32), &mut target);
+
+    assert_eq!(
+        target.exprs.iter().count(),
+        source.exprs.iter().count(),
+        "expression count must match"
+    );
+
+    let mut source_types: Vec<String> = source
+        .exprs
+        .iter()
+        .map(|(_, e)| format!("{:?}", e.ty))
+        .collect();
+    let mut target_types: Vec<String> = target
+        .exprs
+        .iter()
+        .map(|(_, e)| format!("{:?}", e.ty))
+        .collect();
+    source_types.sort();
+    target_types.sort();
+    assert_eq!(source_types, target_types, "expression types must match");
+}
+
+#[test]
+fn clone_preserves_pattern_types_and_kinds() {
+    let source = make_test_package();
+    let mut target = empty_package();
+    let mut cloner = FirCloner::new(&target);
+    cloner.clone_block(&source, BlockId::from(0u32), &mut target);
+
+    assert_eq!(
+        target.pats.iter().count(),
+        source.pats.iter().count(),
+        "pattern count must match"
+    );
+
+    let mut source_types: Vec<String> = source
+        .pats
+        .iter()
+        .map(|(_, p)| format!("{:?}", p.ty))
+        .collect();
+    let mut target_types: Vec<String> = target
+        .pats
+        .iter()
+        .map(|(_, p)| format!("{:?}", p.ty))
+        .collect();
+    source_types.sort();
+    target_types.sort();
+    assert_eq!(source_types, target_types, "pattern types must match");
+
+    let source_bind_count = source
+        .pats
+        .iter()
+        .filter(|(_, p)| matches!(p.kind, PatKind::Bind(_)))
+        .count();
+    let target_bind_count = target
+        .pats
+        .iter()
+        .filter(|(_, p)| matches!(p.kind, PatKind::Bind(_)))
+        .count();
+    assert_eq!(
+        source_bind_count, target_bind_count,
+        "bind pattern count must match"
+    );
+}
+
+#[test]
+#[allow(clippy::similar_names)]
+fn clone_nested_item_preserves_callable_signature() {
+    let mut source = make_test_package();
+
+    // Add a Discard input pattern for the callable.
+    let input_pat = Pat {
+        id: PatId::from(2u32),
+        span: Span::default(),
+        ty: Ty::UNIT,
+        kind: PatKind::Discard,
+    };
+    source.pats.insert(PatId::from(2u32), input_pat);
+
+    // Add a callable item using block 0 as its body.
+    let item = Item {
+        id: LocalItemId::from(0usize),
+        span: Span::default(),
+        parent: None,
+        doc: Rc::from(""),
+        attrs: vec![],
+        visibility: Visibility::Public,
+        kind: ItemKind::Callable(Box::new(CallableDecl {
+            id: NodeId::from(10u32),
+            span: Span::default(),
+            kind: CallableKind::Function,
+            name: Ident {
+                id: LocalVarId::default(),
+                span: Span::default(),
+                name: "test_fn".into(),
+            },
+            generics: vec![],
+            input: PatId::from(2u32),
+            output: Ty::Prim(Prim::Int),
+            functors: FunctorSetValue::Empty,
+            implementation: CallableImpl::Spec(SpecImpl {
+                body: SpecDecl {
+                    id: NodeId::from(11u32),
+                    span: Span::default(),
+                    block: BlockId::from(0u32),
+                    input: None,
+                    exec_graph: ExecGraph::default(),
+                },
+                adj: None,
+                ctl: None,
+                ctl_adj: None,
+            }),
+            attrs: vec![],
+        })),
+    };
+    source.items.insert(LocalItemId::from(0usize), item);
+
+    let mut target = empty_package();
+    let mut cloner = FirCloner::new(&target);
+    let new_item_id = cloner.clone_nested_item(&source, LocalItemId::from(0usize), &mut target);
+
+    assert_eq!(target.items.iter().count(), 1, "cloned item count");
+
+    let orig = source
+        .items
+        .get(LocalItemId::from(0usize))
+        .expect("source item");
+    let cloned =
target.items.get(new_item_id).expect("cloned item"); + + if let (ItemKind::Callable(orig_decl), ItemKind::Callable(new_decl)) = + (&orig.kind, &cloned.kind) + { + assert_eq!(orig_decl.kind, new_decl.kind, "callable kind"); + assert_eq!(orig_decl.output, new_decl.output, "return type"); + assert_eq!(orig_decl.functors, new_decl.functors, "functors"); + assert_eq!( + orig_decl.generics.len(), + new_decl.generics.len(), + "generics count" + ); + + // Verify the body block was cloned with matching stmt count. + if let (CallableImpl::Spec(orig_spec), CallableImpl::Spec(new_spec)) = + (&orig_decl.implementation, &new_decl.implementation) + { + let orig_block = source.blocks.get(orig_spec.body.block).expect("orig block"); + let new_block = target.blocks.get(new_spec.body.block).expect("new block"); + assert_eq!( + orig_block.stmts.len(), + new_block.stmts.len(), + "body stmt count" + ); + } else { + panic!("expected CallableImpl::Spec on both"); + } + } else { + panic!("expected ItemKind::Callable on both"); + } +} diff --git a/source/compiler/qsc_fir_transforms/src/defunctionalize.rs b/source/compiler/qsc_fir_transforms/src/defunctionalize.rs new file mode 100644 index 0000000000..196acddcf7 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/defunctionalize.rs @@ -0,0 +1,630 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Defunctionalization pass. +//! +//! Eliminates all callable-valued expressions (arrow-typed locals, closures, +//! functor-applied callable values) in entry-reachable code through a +//! dispatch-free specialization approach. Unlike classical defunctionalization +//! (which introduces a tagged union and an `apply` function), this +//! implementation directly specializes each higher-order function (HOF) call +//! site where a concrete callable argument is known at compile time. +//! Single-bound tuple parameters whose type contains callable values are +//! 
supported via a split locator model that tracks the top-level parameter +//! slot separately from the nested tuple field path. +//! +//! Establishes [`crate::invariants::InvariantLevel::PostDefunc`]: no +//! `ExprKind::Closure` remains in reachable code, no arrow-typed callable +//! parameters remain in reachable declarations, and all indirect dispatch +//! is rewritten to direct dispatch. +//! +//! Each iteration of the fixpoint loop consists of three phases: +//! +//! - **Analysis** — discovers callable-typed parameters in HOFs, collects +//! call sites where those HOFs are invoked with concrete callable arguments, +//! and runs an identity-closure peephole optimization that replaces +//! `(args) => f(args)` wrappers with direct references to `f`. +//! - **Specialization** — clones each HOF for each concrete argument +//! combination, replacing the callable parameter reference with a direct +//! call to the concrete callee. A deduplication map keyed by [`types::SpecKey`] +//! ensures identical specializations are created only once. +//! - **Rewrite** — redirects original call sites to invoke the specialized +//! clones, removes the callable argument from the argument tuple, and +//! threads closure captures as extra arguments. +//! +//! These phases iterate until no reachable closures or arrow-typed parameters +//! remain in the target package. The iteration limit is dynamically scaled on +//! the first pass based on the number of discovered callable values +//! (`remaining_count.clamp(5, 20)`), preventing unnecessary iterations for +//! simple programs while allowing complex HOF nesting patterns to resolve. If +//! convergence is not reached within the limit, an error is reported. +//! +//! In the future, this pass could be extended to support tagged-union-style +//! defunctionalization for cases where specialization does not converge, +//! but the current approach is required for QIR generation because the QIR +//! 
specification requires direct calls to known callees. +//! +//! # Input patterns +//! +//! - `operation Apply(op : Qubit => Unit, q : Qubit) { op(q); }` — an arrow +//! parameter consumed by a HOF. +//! - `Apply(H, q)` — a call site binding the arrow parameter to a concrete +//! global callable. +//! - `Apply(q => Y(q), q)` — a call site binding the arrow parameter to a +//! lambda. +//! +//! # Rewrites +//! +//! ```text +//! // Before +//! operation Apply(op : Qubit => Unit, q : Qubit) { op(q); } +//! Apply(q => Y(q), target); +//! +//! // After (closure identity peephole collapses the lambda to `Y`) +//! operation Apply_specialized_Y(q : Qubit) { Y(q); } +//! Apply_specialized_Y(target); +//! ``` +//! +//! # Notes +//! +//! - Synthesized expressions use `EMPTY_EXEC_RANGE`; the +//! [`crate::exec_graph_rebuild`] pass repairs them at the end of the +//! pipeline. + +mod analysis; +mod prepass; +mod rewrite; +mod specialize; +pub mod types; + +pub use types::Error; + +#[cfg(test)] +mod tests; + +#[cfg(all(test, feature = "slow-proptest-tests"))] +mod semantic_equivalence_tests; + +use crate::fir_builder::reachable_local_callables; +use crate::reachability::collect_reachable_from_entry; +use crate::walk_utils::collect_expr_ids_in_entry_and_local_callables; +use qsc_data_structures::functors::FunctorApp; +use qsc_data_structures::span::Span; +use qsc_fir::assigner::Assigner; +use qsc_fir::fir::{ + ExprId, ExprKind, ItemKind, LocalItemId, Package, PackageId, PackageLookup, PackageStore, Res, + StoreItemId, +}; +use qsc_fir::ty::Ty; +use rustc_hash::{FxHashMap, FxHashSet}; +use types::{ + AnalysisResult, CallSite, CallableParam, ConcreteCallable, ConcreteCallableKey, SpecKey, + peel_body_functors, +}; + +/// Maximum number of analysis → specialize → rewrite iterations before +/// reporting a convergence failure. 
+/// +/// The value of 5 is the floor: after the first iteration the limit is +/// recomputed as +/// `max(callable_params.len(), remaining_count).clamp(MAX_ITERATIONS, 20)`, +/// giving one iteration of margin beyond the deepest observed HOF chain +/// (4 levels in the chemistry library's Trotter simulation pipeline) and +/// an upper bound of 20 iterations for pathological programs. +const MAX_ITERATIONS: usize = 5; + +/// Defunctionalizes all callable-valued expressions in the entry-reachable +/// portion of a package. +/// +/// After this pass: +/// - No `ExprKind::Closure` nodes remain in reachable code. +/// - No arrow-typed parameters remain in reachable callable declarations. +/// - All indirect callable dispatch is replaced with direct dispatch calls. +/// +/// Returns a vector of errors encountered during defunctionalization. +/// An empty vector indicates success. +pub fn defunctionalize( + store: &mut PackageStore, + package_id: PackageId, + assigner: &mut Assigner, +) -> Vec { + let package = store.get(package_id); + if package.entry.is_none() { + return vec![]; + } + + let mut errors: Vec = Vec::new(); + let mut warnings: Vec = Vec::new(); + let mut max_iterations = MAX_ITERATIONS; + let mut iteration_count = 0; + let mut specialized_closure_targets: FxHashSet = FxHashSet::default(); + let mut specialized_items: FxHashSet = FxHashSet::default(); + + // Capture initial callable-value count for before/after progress tracking + // (mirrors LLVM DevirtSCCRepeatedPass: detect when an iteration fails to + // reduce the remaining work set). + let (_, mut prev_remaining_count, _) = remaining_callable_value_info(store, package_id); + + while iteration_count < max_iterations { + iteration_count += 1; + + // Clear DynamicCallable errors from prior iterations. These are + // re-discovered each pass, and transient ones (e.g. 
parameter + // forwarding like `Inner(op, q)` inside a HOF that hasn't been + // specialized yet) disappear once the outer HOF is specialized. + errors.retain(|e| !matches!(e, Error::DynamicCallable(_))); + + let reachable = collect_reachable_from_entry(store, package_id); + + let (local_item_ids, reachable_expr_ids) = + collect_reachable_scope(store, package_id, &reachable); + + // Simplify defunctionalization analysis by eliminating callable + // indirection patterns and exposing direct call sites. + prepass::run(store, package_id, &reachable_expr_ids); + + let analysis = analysis::analyze(store, package_id, &reachable); + + let spec_map = run_specialization( + store, + package_id, + &analysis, + assigner, + &mut errors, + &mut warnings, + ); + + // Rewrite call sites and run dead callable-local cleanup even on + // iterations where no new specializations were discovered. + let package = store.get_mut(package_id); + rewrite::rewrite(package, package_id, &analysis, &spec_map, assigner); + + track_specialized_closures( + &analysis, + &spec_map, + &mut specialized_closure_targets, + &mut specialized_items, + ); + cleanup_consumed_closures( + package, + &specialized_closure_targets, + &specialized_items, + &local_item_ids, + ); + + // Check convergence + let converged = check_convergence( + store, + package_id, + &analysis, + iteration_count, + &mut max_iterations, + &mut prev_remaining_count, + ); + if converged { + break; + } + } + + emit_fixpoint_error(store, package_id, iteration_count, &mut errors); + errors.extend(warnings); + + errors +} + +/// Computes the reachable local callable IDs and expression IDs for scoping +/// the prepass and cleanup to entry-reachable code. 
+fn collect_reachable_scope( + store: &PackageStore, + package_id: PackageId, + reachable: &FxHashSet, +) -> (Vec, Vec) { + let package = store.get(package_id); + let local_item_ids: Vec<_> = reachable_local_callables(package, package_id, reachable) + .map(|(id, _)| id) + .collect(); + let reachable_expr_ids = + collect_expr_ids_in_entry_and_local_callables(package, &local_item_ids); + (local_item_ids, reachable_expr_ids) +} + +/// Runs specialization if there are call sites, separating warnings from +/// errors. Returns the specialization map. +fn run_specialization( + store: &mut PackageStore, + package_id: PackageId, + analysis: &AnalysisResult, + assigner: &mut Assigner, + errors: &mut Vec, + warnings: &mut Vec, +) -> FxHashMap { + let (spec_map, mut spec_errors) = if analysis.call_sites.is_empty() { + (Default::default(), Vec::new()) + } else { + specialize::specialize(store, package_id, analysis, assigner) + }; + // Separate warnings from errors so the `retain` at the top of each + // iteration does not discard them. + warnings.extend( + spec_errors + .iter() + .filter(|e| matches!(e, Error::ExcessiveSpecializations(..))) + .cloned(), + ); + spec_errors.retain(|e| !matches!(e, Error::ExcessiveSpecializations(..))); + errors.append(&mut spec_errors); + spec_map +} + +/// Records which closure targets were consumed by specialization in this +/// iteration. +fn track_specialized_closures( + analysis: &AnalysisResult, + spec_map: &FxHashMap, + specialized_closure_targets: &mut FxHashSet, + specialized_items: &mut FxHashSet, +) { + for cs in &analysis.call_sites { + let spec_key = build_spec_key(cs); + if spec_map.contains_key(&spec_key) + && let ConcreteCallable::Closure { target, .. } = &cs.callable_arg + { + specialized_closure_targets.insert(*target); + } + } + specialized_items.extend(spec_map.values().copied()); +} + +/// Checks whether the fixed-point loop should terminate. Returns `true` when +/// the loop should break (converged or stuck). 
+fn check_convergence( + store: &PackageStore, + package_id: PackageId, + analysis: &AnalysisResult, + iteration_count: usize, + max_iterations: &mut usize, + prev_remaining_count: &mut usize, +) -> bool { + let (has_remaining, remaining_count, _) = remaining_callable_value_info(store, package_id); + + let made_progress = remaining_count < *prev_remaining_count || !analysis.call_sites.is_empty(); + *prev_remaining_count = remaining_count; + + // On the first iteration, compute a dynamic iteration limit based on + // the number of remaining callable values discovered. + if iteration_count == 1 { + *max_iterations = analysis + .callable_params + .len() + .max(remaining_count) + .clamp(MAX_ITERATIONS, 20); + } + + if !has_remaining { + return true; + } + + // No progress was made — the loop is stuck. Break out and let + // `emit_fixpoint_error` report the remaining callable values. + if !made_progress { + return true; + } + + false +} + +/// Emits a `FixpointNotReached` error if callable values remain after the +/// loop exits. +fn emit_fixpoint_error( + store: &PackageStore, + package_id: PackageId, + iteration_count: usize, + errors: &mut Vec, +) { + let (has_remaining, remaining_count, span) = remaining_callable_value_info(store, package_id); + if has_remaining && errors.is_empty() { + errors.push(Error::FixpointNotReached( + iteration_count, + remaining_count, + span, + )); + } +} + +/// Replaces all remaining closure expressions whose target callable was +/// consumed by specialization with Unit values, clearing references so +/// subsequent iterations do not count them as work remaining. +/// +/// A closure is "consumed" when its target callable has been specialized — +/// meaning the HOF call site that passed this closure as an argument has been +/// rewritten to a direct call to the specialized version. 
The closure node +/// in the producer function body is now dead: no analysis will discover new +/// call sites for it, but `remaining_callable_value_info` would still count +/// it as work remaining, causing false convergence failure. +/// +/// Only closures that are NOT direct children of a `Call` argument subtree +/// are eligible for cleanup. Closures that are still live as arguments to a +/// call expression (e.g., in a multi-param HOF where only one param has been +/// specialized so far) must survive to the next iteration. +/// +/// # Before +/// ```text +/// Closure([captures], consumed_target) : Arrow +/// ``` +/// # After +/// ```text +/// Tuple([]) : Unit // closure replaced with unit +/// ``` +/// +/// # Mutations +/// - Rewrites `Expr.kind` to `Tuple(Vec::new())` and `Expr.ty` to `Unit` +/// for consumed closure expressions outside call-argument subtrees. +fn cleanup_consumed_closures( + package: &mut Package, + specialized_targets: &FxHashSet, + skip_items: &FxHashSet, + reachable_item_ids: &[LocalItemId], +) -> usize { + if specialized_targets.is_empty() { + return 0; + } + + // First pass: collect the ExprIds of all call argument subtrees. + // Closures inside these subtrees are still live as HOF arguments. 
+ let mut call_arg_exprs: FxHashSet = FxHashSet::default(); + for &item_id in reachable_item_ids { + if skip_items.contains(&item_id) { + continue; + } + let item = package.get_item(item_id); + if let ItemKind::Callable(decl) = &item.kind { + crate::walk_utils::for_each_expr_in_callable_impl( + package, + &decl.implementation, + &mut |_expr_id, expr| { + if let ExprKind::Call(_, args_id) = &expr.kind { + collect_all_expr_ids(package, *args_id, &mut call_arg_exprs); + } + }, + ); + } + } + if let Some(entry_id) = package.entry { + crate::walk_utils::for_each_expr(package, entry_id, &mut |_expr_id, expr| { + if let ExprKind::Call(_, args_id) = &expr.kind { + collect_all_expr_ids(package, *args_id, &mut call_arg_exprs); + } + }); + } + + // Second pass: collect consumed closures that are NOT in call argument + // positions. + let mut to_replace: Vec = Vec::new(); + for &item_id in reachable_item_ids { + if skip_items.contains(&item_id) { + continue; + } + let item = package.get_item(item_id); + if let ItemKind::Callable(decl) = &item.kind { + crate::walk_utils::for_each_expr_in_callable_impl( + package, + &decl.implementation, + &mut |expr_id, expr| { + if let ExprKind::Closure(_, target) = &expr.kind + && specialized_targets.contains(target) + && !call_arg_exprs.contains(&expr_id) + { + to_replace.push(expr_id); + } + }, + ); + } + } + + if let Some(entry_id) = package.entry { + crate::walk_utils::for_each_expr(package, entry_id, &mut |expr_id, expr| { + if let ExprKind::Closure(_, target) = &expr.kind + && specialized_targets.contains(target) + && !call_arg_exprs.contains(&expr_id) + { + to_replace.push(expr_id); + } + }); + } + + let count = to_replace.len(); + for expr_id in to_replace { + let expr = package.exprs.get_mut(expr_id).expect("expr must exist"); + expr.kind = ExprKind::Tuple(Vec::new()); + expr.ty = Ty::UNIT; + } + + count +} + +/// Recursively collects all `ExprId`s reachable from an expression node. 
+fn collect_all_expr_ids(package: &Package, expr_id: ExprId, ids: &mut FxHashSet) { + crate::walk_utils::for_each_expr(package, expr_id, &mut |child_id, _| { + ids.insert(child_id); + }); +} + +/// Checks whether any reachable target-package callable value still requires +/// defunctionalization work. +/// +/// Returns `(has_remaining, count, first_span)` in a single reachability scan. +fn remaining_callable_value_info( + store: &PackageStore, + package_id: PackageId, +) -> (bool, usize, Span) { + let reachable = collect_reachable_from_entry(store, package_id); + let package = store.get(package_id); + let mut count = 0; + let mut first_span = Span::default(); + + let mut record_remaining = |span: Span| { + if count == 0 { + first_span = span; + } + count += 1; + }; + + for store_id in &reachable { + if store_id.package != package_id { + continue; + } + let item = package.get_item(store_id.item); + if let ItemKind::Callable(decl) = &item.kind { + let input_pat = package.get_pat(decl.input); + if ty_contains_arrow_through_udts(store, &input_pat.ty) { + record_remaining(input_pat.span); + } + + crate::walk_utils::for_each_expr_in_callable_impl( + package, + &decl.implementation, + &mut |_expr_id, expr| { + if matches!(expr.kind, ExprKind::Closure(_, _)) { + record_remaining(expr.span); + } + // Count indirect calls through arrow-typed local variables. + // After defunc iteration 1 specializes HOFs and removes callable + // parameters, conditional callable bindings like + // let u = if power >= 0 { op } else { Adjoint op }; + // u(target); + // leave arrow-typed locals with indirect Call expressions. + // The existing branch-split infrastructure resolves these in + // a subsequent iteration, but only if the convergence check + // reports them as remaining. 
+ if let ExprKind::Call(callee_id, _) = &expr.kind { + let (base_id, _) = peel_body_functors(package, *callee_id); + let base_expr = package.get_expr(base_id); + if matches!(base_expr.kind, ExprKind::Var(Res::Local(_), _)) + && ty_contains_arrow(&base_expr.ty) + { + record_remaining(base_expr.span); + } + } + }, + ); + } + } + + if let Some(entry_id) = package.entry { + crate::walk_utils::for_each_expr(package, entry_id, &mut |_expr_id, expr| { + if matches!(expr.kind, ExprKind::Closure(_, _)) { + record_remaining(expr.span); + } + // Same indirect-call check as callable body walker. + if let ExprKind::Call(callee_id, _) = &expr.kind { + let (base_id, _) = peel_body_functors(package, *callee_id); + let base_expr = package.get_expr(base_id); + if matches!(base_expr.kind, ExprKind::Var(Res::Local(_), _)) + && ty_contains_arrow(&base_expr.ty) + { + record_remaining(base_expr.span); + } + } + }); + } + + (count > 0, count, first_span) +} + +/// Checks whether a type contains an arrow type anywhere within its structure. +/// +/// This intentionally does NOT recurse into `Ty::Udt` or `Ty::Array`: +/// +/// - **`Ty::Udt`**: Defunc runs before UDT erasure, so UDT wrappers are still +/// opaque here. Callable values inside UDTs are handled at the *expression* +/// level by the analysis phase (`extract_arrow_params_from_ty` also ignores +/// `Ty::Udt`, but `build_callable_flow_state` tracks field-extraction +/// expressions like `config.Op` to resolve concrete callable values). After +/// defunc, callable values are either specialized or rejected as +/// `DynamicCallable`. Post-UDT-erasure passes (SROA, `arg_promote`) may expose +/// bare `Ty::Arrow` parameters, but partial eval handles them correctly +/// because it dispatches on *values* (`Value::Global` / `Value::Closure`), +/// not on the `Ty::Arrow` type annotation. +/// +/// - **`Ty::Array`**: Array-of-callable parameters (`(Qubit => Unit)[]`) are +/// dynamically indexed, so defunc cannot specialize them. 
Ignoring +/// `Ty::Array` is consistent with defunc's capabilities. +/// +/// A separate copy of this function in `codegen.rs` does handle `Ty::Array` +/// for codegen routing; unifying the two is unnecessary because their +/// contexts differ. +pub(crate) fn ty_contains_arrow(ty: &Ty) -> bool { + match ty { + Ty::Arrow(_) => true, + Ty::Tuple(tys) => tys.iter().any(ty_contains_arrow), + _ => false, + } +} + +/// Checks whether a type contains an arrow, expanding UDT pure types recursively. +/// +/// The defunctionalization fixpoint uses this for reachable callable inputs so a +/// callable whose parameter is a UDT containing a callable field keeps the loop +/// running until that nested callable field is specialized. The rewrite helpers +/// still use `ty_contains_arrow`, where UDTs intentionally remain opaque. +fn ty_contains_arrow_through_udts(store: &PackageStore, ty: &Ty) -> bool { + match ty { + Ty::Arrow(_) => true, + Ty::Tuple(tys) => tys + .iter() + .any(|ty| ty_contains_arrow_through_udts(store, ty)), + Ty::Udt(Res::Item(item_id)) => { + let package = store.get(item_id.package); + let item = package.get_item(item_id.item); + let ItemKind::Ty(_, udt) = &item.kind else { + return false; + }; + ty_contains_arrow_through_udts(store, &udt.get_pure_ty()) + } + _ => false, + } +} + +/// Builds the deduplication key for a call site's specialization. +pub(crate) fn build_spec_key(call_site: &CallSite) -> SpecKey { + let concrete_key = match &call_site.callable_arg { + ConcreteCallable::Global { item_id, functor } => ConcreteCallableKey::Global { + item_id: *item_id, + functor: *functor, + }, + ConcreteCallable::Closure { + target, functor, .. + } => ConcreteCallableKey::Closure { + target: *target, + functor: *functor, + }, + ConcreteCallable::Dynamic => { + // Dynamic callables are filtered out before reaching here, but + // provide a deterministic key regardless. 
+ ConcreteCallableKey::Global { + item_id: call_site.hof_item_id, + functor: FunctorApp::default(), + } + } + }; + SpecKey { + hof_id: call_site.hof_item_id.item, + concrete_args: vec![concrete_key], + } +} + +/// Builds the index path from a call's argument tuple to the position of +/// a callable parameter, accounting for functor control wrappers and +/// tuple-patterned inputs. +pub(crate) fn build_param_input_path( + uses_tuple_input: bool, + param: &CallableParam, + functor: FunctorApp, +) -> Vec { + let mut path = vec![1; usize::from(functor.controlled)]; + if uses_tuple_input { + path.push(param.top_level_param); + } + path.extend(param.field_path.iter().copied()); + path +} diff --git a/source/compiler/qsc_fir_transforms/src/defunctionalize/analysis.rs b/source/compiler/qsc_fir_transforms/src/defunctionalize/analysis.rs new file mode 100644 index 0000000000..f356894afd --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/defunctionalize/analysis.rs @@ -0,0 +1,1872 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Analysis phase of the defunctionalization pass. +//! +//! Discovers callable-typed parameters in higher-order functions, collects +//! call sites where those HOFs are invoked with concrete callable arguments, +//! and resolves each argument to a [`ConcreteCallable`]. +//! +//! # Responsibilities +//! +//! - Discover arrow-typed callable parameters on reachable declarations +//! (via [`find_callable_params`] / [`extract_arrow_params_from_ty`]). +//! - Collect direct and HOF call sites (via [`collect_call_sites`] / +//! [`inspect_call_expr`] / [`inspect_direct_call_expr`]). +//! - Run the identity-closure peephole that replaces `(args) => f(args)` +//! closures with direct references to `f` (via +//! [`identity_closure_peephole`]). +//! - Resolve callee expressions to concrete callables using flow-sensitive +//! reaching definitions, closure captures, functor applications, indexed +//! 
array elements, struct field accesses, and same-package callable +//! returns (via [`resolve_callee`] and its helpers). +//! - Build per-callable lattice states that expose reaching-definition +//! information back to the specialization and rewrite phases (via +//! [`build_callable_flow_state`] / [`analyze_spec_flow`]). + +use super::types::{ + AnalysisResult, CallSite, CallableParam, CalleeLattice, CapturedVar, ConcreteCallable, + DirectCallSite, LatticeStates, compose_functors, peel_body_functors, +}; +use crate::fir_builder::functored_specs; +use qsc_data_structures::functors::FunctorApp; +use qsc_fir::fir::{ + BlockId, CallableImpl, ExprId, ExprKind, Field, FieldAssign, FieldPath, ItemId, ItemKind, Lit, + LocalVarId, Mutability, Package, PackageId, PackageLookup, PackageStore, PatId, PatKind, Res, + SpecImpl, StmtKind, StoreItemId, UnOp, +}; +use qsc_fir::ty::Ty; +use rustc_hash::{FxHashMap, FxHashSet}; + +/// Combined local variable state for the analysis phase. +/// +/// `callable` holds flow-sensitive reaching-definitions for callable-typed +/// locals (both mutable and immutable). `exprs` holds raw `ExprId` bindings +/// for all immutable locals, supporting struct field resolution and type +/// look-ups. +#[derive(Default)] +pub(super) struct LocalState { + callable: FxHashMap, + exprs: FxHashMap, +} + +/// Maximum recursion depth when resolving callee expressions to prevent +/// infinite loops from unexpected circular references. +const MAX_RESOLVE_DEPTH: usize = 32; + +/// Runs the analysis phase: finds callable parameters and collects call sites. 
+pub(super) fn analyze( + store: &mut PackageStore, + package_id: PackageId, + reachable: &FxHashSet, +) -> AnalysisResult { + let hof_params = find_callable_params(store, reachable); + let (call_sites, direct_call_sites, lattice_states) = + collect_call_sites(store, package_id, reachable, &hof_params); + AnalysisResult { + callable_params: hof_params.into_values().flatten().collect(), + call_sites, + direct_call_sites, + lattice_states, + } +} + +/// Scans all reachable callables (including cross-package ones like the +/// standard library) and returns a map from each HOF's `StoreItemId` to the +/// list of its arrow-typed parameters. +fn find_callable_params( + store: &PackageStore, + reachable: &FxHashSet, +) -> FxHashMap> { + let mut result: FxHashMap> = FxHashMap::default(); + + for &store_id in reachable { + let pkg = store.get(store_id.package); + let item = pkg.get_item(store_id.item); + if let ItemKind::Callable(decl) = &item.kind { + let params = extract_arrow_params(store, pkg, store_id.item, decl.input); + if !params.is_empty() { + result.insert(store_id, params); + } + } + } + + result +} + +/// Extracts arrow-typed parameters from a callable's input pattern. 
+fn extract_arrow_params( + store: &PackageStore, + pkg: &Package, + callable_id: qsc_fir::fir::LocalItemId, + input_pat_id: qsc_fir::fir::PatId, +) -> Vec { + let pat = pkg.get_pat(input_pat_id); + let mut params = Vec::new(); + + match &pat.kind { + PatKind::Tuple(sub_pats) => { + for (index, &sub_pat_id) in sub_pats.iter().enumerate() { + let sub_pat = pkg.get_pat(sub_pat_id); + if let PatKind::Bind(ident) = &sub_pat.kind { + let mut field_path = Vec::new(); + let context = ArrowParamExtraction { + store, + callable_id, + param_pat_id: sub_pat_id, + param_var: ident.id, + top_level_param: index, + }; + extract_arrow_params_from_ty( + &context, + &sub_pat.ty, + &mut field_path, + &mut params, + ); + } + } + } + PatKind::Bind(ident) => { + let mut field_path = Vec::new(); + let context = ArrowParamExtraction { + store, + callable_id, + param_pat_id: input_pat_id, + param_var: ident.id, + top_level_param: 0, + }; + extract_arrow_params_from_ty(&context, &pat.ty, &mut field_path, &mut params); + } + PatKind::Discard => {} + } + + params +} + +/// Carries the invariant metadata needed while extracting callable parameters. +struct ArrowParamExtraction<'a> { + store: &'a PackageStore, + callable_id: qsc_fir::fir::LocalItemId, + param_pat_id: PatId, + param_var: LocalVarId, + top_level_param: usize, +} + +/// Recursively descends into the structural layers of a callable parameter +/// type and records every `Ty::Arrow` leaf as a `CallableParam`. +/// +/// UDTs are expanded to their pure type so callable fields inside nested +/// newtypes are treated the same way as tuple fields. 
+fn extract_arrow_params_from_ty( + context: &ArrowParamExtraction<'_>, + param_ty: &Ty, + field_path: &mut Vec, + params: &mut Vec, +) { + match param_ty { + Ty::Arrow(_) => params.push(CallableParam::new( + context.callable_id, + context.param_pat_id, + context.top_level_param, + field_path.clone(), + context.param_var, + param_ty.clone(), + )), + Ty::Tuple(items) => { + for (index, item_ty) in items.iter().enumerate() { + field_path.push(index); + extract_arrow_params_from_ty(context, item_ty, field_path, params); + field_path.pop(); + } + } + Ty::Udt(Res::Item(item_id)) => { + let package = context.store.get(item_id.package); + let item = package.get_item(item_id.item); + let ItemKind::Ty(_, udt) = &item.kind else { + return; + }; + extract_arrow_params_from_ty(context, &udt.get_pure_ty(), field_path, params); + } + _ => {} + } +} + +/// Walks the bodies of all reachable callables in the target package and +/// collects call sites where a HOF is invoked with a concrete callable +/// argument. +fn collect_call_sites( + store: &PackageStore, + package_id: PackageId, + reachable: &FxHashSet, + hof_params: &FxHashMap>, +) -> (Vec, Vec, LatticeStates) { + let package = store.get(package_id); + let mut call_sites = Vec::new(); + let mut direct_call_sites = Vec::new(); + let mut lattice_states: LatticeStates = FxHashMap::default(); + + for &store_id in reachable { + if store_id.package != package_id { + continue; + } + let item = package.get_item(store_id.item); + if let ItemKind::Callable(decl) = &item.kind { + let locals = + build_callable_flow_state(package, store, &decl.implementation, package_id); + + // Capture non-Bottom lattice entries, sorted by LocalVarId. 
+ let mut entries: Vec<(LocalVarId, CalleeLattice)> = locals + .callable + .iter() + .filter(|(_, lat)| !matches!(lat, CalleeLattice::Bottom)) + .map(|(var, lat)| (*var, lat.clone())) + .collect(); + entries.sort_by_key(|(var, _)| *var); + if !entries.is_empty() { + lattice_states.insert(store_id.item, entries); + } + + walk_callable_for_calls( + store, + package, + &decl.implementation, + hof_params, + &locals, + &mut call_sites, + &mut direct_call_sites, + package_id, + ); + } + } + + if let Some(entry_expr_id) = package.entry { + let mut locals = LocalState { + callable: FxHashMap::default(), + exprs: FxHashMap::default(), + }; + analyze_expr_flow(package, store, entry_expr_id, &mut locals, package_id); + crate::walk_utils::for_each_expr(package, entry_expr_id, &mut |expr_id, expr| { + inspect_call_expr( + store, + package, + expr_id, + expr, + hof_params, + &locals, + &mut call_sites, + &mut direct_call_sites, + package_id, + ); + }); + } + + (call_sites, direct_call_sites, lattice_states) +} + +/// Walks the specialisation bodies of a callable implementation looking for +/// `ExprKind::Call` nodes whose callee is a known HOF. +#[allow(clippy::too_many_arguments)] +fn walk_callable_for_calls( + store: &PackageStore, + pkg: &Package, + callable_impl: &CallableImpl, + hof_params: &FxHashMap>, + locals: &LocalState, + call_sites: &mut Vec, + direct_call_sites: &mut Vec, + package_id: PackageId, +) { + crate::walk_utils::for_each_expr_in_callable_impl(pkg, callable_impl, &mut |expr_id, expr| { + inspect_call_expr( + store, + pkg, + expr_id, + expr, + hof_params, + locals, + call_sites, + direct_call_sites, + package_id, + ); + }); +} + +/// Inspects a single expression for HOF call-site patterns. 
+#[allow(clippy::too_many_arguments)] +fn inspect_call_expr( + store: &PackageStore, + pkg: &Package, + expr_id: ExprId, + expr: &qsc_fir::fir::Expr, + hof_params: &FxHashMap>, + locals: &LocalState, + call_sites: &mut Vec, + direct_call_sites: &mut Vec, + package_id: PackageId, +) { + let ExprKind::Call(callee_expr_id, args_expr_id) = &expr.kind else { + return; + }; + + if expr_contains_hole(pkg, *args_expr_id) { + return; + } + + if let Some((hof_store_id, hof_functor, hof_callable_params)) = + resolve_hof_callee(pkg, *callee_expr_id, hof_params) + { + let uses_tuple_input = hof_uses_tuple_input_pattern(store, hof_store_id); + for cp in hof_callable_params { + let input_path = super::build_param_input_path(uses_tuple_input, cp, hof_functor); + let resolved_arg_id = extract_arg_at_path(pkg, *args_expr_id, &input_path); + let allow_scoped_capture_exprs = matches!( + pkg.get_expr(resolved_arg_id).kind, + ExprKind::Block(_) | ExprKind::If(_, _, _) + ); + let resolved = resolve_callee_at_path( + pkg, + store, + locals, + *args_expr_id, + &input_path, + 0, + allow_scoped_capture_exprs, + &FxHashSet::default(), + package_id, + ); + match resolved { + CalleeLattice::Single(cc) => { + call_sites.push(CallSite { + call_expr_id: expr_id, + hof_item_id: ItemId { + package: hof_store_id.package, + item: hof_store_id.item, + }, + callable_arg: cc, + arg_expr_id: resolved_arg_id, + condition: None, + }); + } + CalleeLattice::Multi(candidates) => { + for (cc, cond) in candidates { + call_sites.push(CallSite { + call_expr_id: expr_id, + hof_item_id: ItemId { + package: hof_store_id.package, + item: hof_store_id.item, + }, + callable_arg: cc, + arg_expr_id: resolved_arg_id, + condition: cond, + }); + } + } + CalleeLattice::Dynamic | CalleeLattice::Bottom => { + call_sites.push(CallSite { + call_expr_id: expr_id, + hof_item_id: ItemId { + package: hof_store_id.package, + item: hof_store_id.item, + }, + callable_arg: ConcreteCallable::Dynamic, + arg_expr_id: resolved_arg_id, + 
condition: None, + }); + } + } + } + + return; + } + + inspect_direct_call_expr( + store, + pkg, + expr_id, + *callee_expr_id, + locals, + direct_call_sites, + package_id, + ); +} + +/// Returns `true` when an expression subtree contains an `ExprKind::Hole` +/// placeholder, which marks partial applications that the pass does not +/// yet specialize. +fn expr_contains_hole(pkg: &Package, expr_id: ExprId) -> bool { + let mut contains_hole = false; + crate::walk_utils::for_each_expr(pkg, expr_id, &mut |_expr_id, expr| { + if matches!(expr.kind, ExprKind::Hole) { + contains_hole = true; + } + }); + contains_hole +} + +/// Inspects a direct `Call(callee, args)` expression whose callee resolves +/// to a concrete callable value (global, closure, or functor-applied +/// callable) and, when resolution succeeds, records a [`DirectCallSite`]. +fn inspect_direct_call_expr( + store: &PackageStore, + pkg: &Package, + expr_id: ExprId, + callee_expr_id: ExprId, + locals: &LocalState, + direct_call_sites: &mut Vec, + package_id: PackageId, +) { + let callee_expr = pkg.get_expr(callee_expr_id); + if matches!(callee_expr.kind, ExprKind::Var(Res::Item(_), _)) { + return; + } + + let resolved = if let ExprKind::Var(Res::Local(var), _) = callee_expr.kind { + if let Some(&init_expr_id) = locals.exprs.get(&var) { + resolve_callee( + pkg, + store, + locals, + init_expr_id, + 0, + true, + &FxHashSet::default(), + package_id, + ) + } else { + resolve_callee( + pkg, + store, + locals, + callee_expr_id, + 0, + false, + &FxHashSet::default(), + package_id, + ) + } + } else { + let allow_scoped_capture_exprs = matches!( + callee_expr.kind, + ExprKind::Block(_) | ExprKind::If(_, _, _) | ExprKind::UnOp(_, _) + ); + resolve_callee( + pkg, + store, + locals, + callee_expr_id, + 0, + allow_scoped_capture_exprs, + &FxHashSet::default(), + package_id, + ) + }; + + match resolved { + CalleeLattice::Single(callable) => { + direct_call_sites.push(DirectCallSite { + call_expr_id: expr_id, + callable, + 
condition: None, + }); + } + CalleeLattice::Multi(candidates) => { + for (callable, condition) in candidates { + direct_call_sites.push(DirectCallSite { + call_expr_id: expr_id, + callable, + condition, + }); + } + } + CalleeLattice::Bottom | CalleeLattice::Dynamic => {} + } +} + +/// Given a callee expression, peel functor layers and check whether the base +/// refers to a callable in the `hof_params` map. Returns the `StoreItemId` of +/// the HOF and a reference to its callable-typed parameters. +fn resolve_hof_callee<'a>( + pkg: &Package, + callee_expr_id: ExprId, + hof_params: &'a FxHashMap>, +) -> Option<(StoreItemId, FunctorApp, &'a Vec)> { + let (base_id, functor) = peel_body_functors(pkg, callee_expr_id); + let base_expr = pkg.get_expr(base_id); + if let ExprKind::Var(Res::Item(item_id), _) = &base_expr.kind { + let store_id = StoreItemId { + package: item_id.package, + item: item_id.item, + }; + hof_params + .get(&store_id) + .map(|params| (store_id, functor, params)) + } else { + None + } +} + +/// Returns `true` when the HOF's input pattern is a single tuple pattern +/// bound to one name. Used to gate tuple-field locator bookkeeping for HOFs +/// whose arrow parameter is nested inside a single tuple binding. +fn hof_uses_tuple_input_pattern(store: &PackageStore, hof_store_id: StoreItemId) -> bool { + let hof_pkg = store.get(hof_store_id.package); + let hof_item = hof_pkg.get_item(hof_store_id.item); + match &hof_item.kind { + ItemKind::Callable(decl) => matches!(hof_pkg.get_pat(decl.input).kind, PatKind::Tuple(_)), + _ => false, + } +} + +/// Extracts the argument expression at the given relative field path from an +/// already-selected outer call argument. 
fn extract_arg_at_path(pkg: &Package, args_expr_id: ExprId, path: &[usize]) -> ExprId {
    // An empty path selects the whole argument expression.
    if path.is_empty() {
        return args_expr_id;
    }
    let args_expr = pkg.get_expr(args_expr_id);
    if let ExprKind::Tuple(elements) = &args_expr.kind {
        // Descend one tuple level per path segment; the final segment picks
        // the element itself.
        if path.len() == 1 {
            elements[path[0]]
        } else {
            extract_arg_at_path(pkg, elements[path[0]], &path[1..])
        }
    } else {
        // Single-parameter callable: the args expression IS the argument.
        args_expr_id
    }
}

/// Resolves a callable argument selected by `path`, following local UDT/tuple
/// initializers when the selected value is nested inside a single argument.
///
/// Resolution strategy, in order:
/// 1. empty path — resolve the argument expression directly;
/// 2. literal tuple — index into the tuple and recurse on the tail path;
/// 3. otherwise — treat `path` as a struct-field locator and trace the field
///    initializer via `resolve_struct_field`;
/// 4. last resort — resolve the whole argument expression as-is.
/// Depth is bumped on every recursion and capped at `MAX_RESOLVE_DEPTH`,
/// yielding `Dynamic` when exceeded.
#[allow(clippy::too_many_arguments)]
fn resolve_callee_at_path(
    pkg: &Package,
    store: &PackageStore,
    locals: &LocalState,
    args_expr_id: ExprId,
    path: &[usize],
    depth: usize,
    allow_scoped_capture_exprs: bool,
    scoped_capture_vars: &FxHashSet<LocalVarId>,
    package_id: PackageId,
) -> CalleeLattice {
    // Recursion fuse: give up conservatively rather than loop on cyclic or
    // deeply nested initializer chains.
    if depth > MAX_RESOLVE_DEPTH {
        return CalleeLattice::Dynamic;
    }

    if path.is_empty() {
        return resolve_callee(
            pkg,
            store,
            locals,
            args_expr_id,
            depth + 1,
            allow_scoped_capture_exprs,
            scoped_capture_vars,
            package_id,
        );
    }

    let args_expr = pkg.get_expr(args_expr_id);
    if let ExprKind::Tuple(elements) = &args_expr.kind
        && let Some(&element_id) = elements.get(path[0])
    {
        return resolve_callee_at_path(
            pkg,
            store,
            locals,
            element_id,
            &path[1..],
            depth + 1,
            allow_scoped_capture_exprs,
            scoped_capture_vars,
            package_id,
        );
    }

    // Not a literal tuple: interpret the remaining path as struct-field
    // indices and chase the field's initializer expression.
    let field_path = FieldPath {
        indices: path.to_vec(),
    };
    if let Some(field_value_id) = resolve_struct_field(pkg, locals, args_expr_id, &field_path, 0) {
        return resolve_callee(
            pkg,
            store,
            locals,
            field_value_id,
            depth + 1,
            allow_scoped_capture_exprs,
            scoped_capture_vars,
            package_id,
        );
    }

    // Fallback: resolve the argument expression itself.
    resolve_callee(
        pkg,
        store,
        locals,
        args_expr_id,
        depth + 1,
        allow_scoped_capture_exprs,
        scoped_capture_vars,
        package_id,
    )
}

/// Resolves an
expression to a [`CalleeLattice`] by peeling functor +/// applications, following single-assignment immutable locals, resolving +/// if-value-expressions, and recognising closures and global item references. +#[allow( + clippy::only_used_in_recursion, + clippy::too_many_lines, + clippy::too_many_arguments +)] +fn resolve_callee( + pkg: &Package, + store: &PackageStore, + locals: &LocalState, + expr_id: ExprId, + depth: usize, + allow_scoped_capture_exprs: bool, + scoped_capture_vars: &FxHashSet, + package_id: PackageId, +) -> CalleeLattice { + if depth > MAX_RESOLVE_DEPTH { + return CalleeLattice::Dynamic; + } + + // First peel any functor application layers. + let (base_id, outer_functor) = peel_body_functors(pkg, expr_id); + let base_expr = pkg.get_expr(base_id); + + let base_resolved = match &base_expr.kind { + ExprKind::Var(Res::Item(item_id), _) => CalleeLattice::Single(ConcreteCallable::Global { + item_id: *item_id, + functor: FunctorApp::default(), + }), + ExprKind::Closure(captured_vars, target) => { + let Some(captures) = resolve_captures(pkg, locals, captured_vars, scoped_capture_vars) + else { + return CalleeLattice::Dynamic; + }; + CalleeLattice::Single(ConcreteCallable::Closure { + target: *target, + captures, + functor: FunctorApp::default(), + }) + } + ExprKind::Var(Res::Local(var), _) => { + // Check flow-sensitive callable lattice first. + if let Some(lattice) = locals.callable.get(var) { + lattice.clone() + } else if let Some(&init_expr_id) = locals.exprs.get(var) { + // Fallback to immutable ExprId bindings (struct fields, etc.). 
+ resolve_callee( + pkg, + store, + locals, + init_expr_id, + depth + 1, + allow_scoped_capture_exprs, + scoped_capture_vars, + package_id, + ) + } else { + CalleeLattice::Dynamic + } + } + ExprKind::Return(inner_expr_id) => resolve_callee( + pkg, + store, + locals, + *inner_expr_id, + depth + 1, + allow_scoped_capture_exprs, + scoped_capture_vars, + package_id, + ), + ExprKind::Call(callee_expr_id, args_expr_id) => { + let callee_lattice = resolve_callee( + pkg, + store, + locals, + *callee_expr_id, + depth + 1, + allow_scoped_capture_exprs, + scoped_capture_vars, + package_id, + ); + + match callee_lattice { + CalleeLattice::Single(ConcreteCallable::Global { item_id, functor }) + if item_id.package == package_id && functor == FunctorApp::default() => + { + resolve_same_package_callable_return( + pkg, + store, + locals, + item_id, + *args_expr_id, + depth + 1, + allow_scoped_capture_exprs, + scoped_capture_vars, + package_id, + ) + } + _ => CalleeLattice::Dynamic, + } + } + ExprKind::Index(array_expr_id, index_expr_id) => { + if let Some(elem_expr_id) = resolve_indexed_array_element( + pkg, + locals, + *array_expr_id, + *index_expr_id, + depth + 1, + ) { + resolve_callee( + pkg, + store, + locals, + elem_expr_id, + depth + 1, + allow_scoped_capture_exprs, + scoped_capture_vars, + package_id, + ) + } else if let Some(candidates) = resolve_indexed_callable_candidates( + pkg, + store, + locals, + *array_expr_id, + depth + 1, + allow_scoped_capture_exprs, + scoped_capture_vars, + package_id, + ) { + CalleeLattice::Multi( + candidates + .into_iter() + .map(|callable| (callable, None)) + .collect(), + ) + } else { + CalleeLattice::Dynamic + } + } + ExprKind::If(cond, body, otherwise) => { + let true_res = resolve_callee( + pkg, + store, + locals, + *body, + depth + 1, + allow_scoped_capture_exprs, + scoped_capture_vars, + package_id, + ); + let false_res = if let Some(else_id) = otherwise { + resolve_callee( + pkg, + store, + locals, + *else_id, + depth + 1, + 
allow_scoped_capture_exprs, + scoped_capture_vars, + package_id, + ) + } else { + CalleeLattice::Dynamic + }; + true_res.join_with_condition(false_res, *cond) + } + ExprKind::Block(block_id) => { + let block = pkg.get_block(*block_id); + let mut block_state = LocalState { + callable: locals.callable.clone(), + exprs: locals.exprs.clone(), + }; + analyze_block_flow(pkg, store, *block_id, &mut block_state, package_id); + let block_scoped_vars = if allow_scoped_capture_exprs { + let mut vars = scoped_capture_vars.clone(); + collect_block_local_bindings(pkg, *block_id, &mut vars); + vars + } else { + scoped_capture_vars.clone() + }; + if let Some(&last_stmt_id) = block.stmts.last() { + let stmt = pkg.get_stmt(last_stmt_id); + match &stmt.kind { + StmtKind::Expr(e) | StmtKind::Semi(e) => resolve_callee( + pkg, + store, + &block_state, + *e, + depth + 1, + allow_scoped_capture_exprs, + &block_scoped_vars, + package_id, + ), + _ => CalleeLattice::Dynamic, + } + } else { + CalleeLattice::Dynamic + } + } + ExprKind::Field(inner_expr_id, Field::Path(path)) => { + if let Some(field_value_id) = + resolve_struct_field(pkg, locals, *inner_expr_id, path, depth + 1) + { + resolve_callee( + pkg, + store, + locals, + field_value_id, + depth + 1, + allow_scoped_capture_exprs, + scoped_capture_vars, + package_id, + ) + } else { + CalleeLattice::Dynamic + } + } + _ => CalleeLattice::Dynamic, + }; + + // Compose the outer functor (from peeling) with the base's functor. + apply_outer_functor_lattice(base_resolved, outer_functor) +} + +/// Attempts to resolve a callable-returning call whose target lives in the +/// same package by treating the target body as a straight-line function, +/// binding its parameters to the call's argument expressions and tracing +/// the result back to a concrete callable. 
+#[allow(clippy::too_many_arguments)] +fn resolve_same_package_callable_return( + pkg: &Package, + store: &PackageStore, + caller_locals: &LocalState, + item_id: ItemId, + args_expr_id: ExprId, + depth: usize, + allow_scoped_capture_exprs: bool, + scoped_capture_vars: &FxHashSet, + package_id: PackageId, +) -> CalleeLattice { + let item = pkg.get_item(item_id.item); + let ItemKind::Callable(decl) = &item.kind else { + return CalleeLattice::Dynamic; + }; + + if !matches!(decl.output, Ty::Arrow(_)) { + return CalleeLattice::Dynamic; + } + + let (body_block_id, body_input) = match &decl.implementation { + CallableImpl::Spec(spec_impl) => ( + spec_impl.body.block, + spec_impl.body.input.unwrap_or(decl.input), + ), + CallableImpl::SimulatableIntrinsic(spec_decl) => { + (spec_decl.block, spec_decl.input.unwrap_or(decl.input)) + } + CallableImpl::Intrinsic => return CalleeLattice::Dynamic, + }; + + let mut state = LocalState { + callable: FxHashMap::default(), + exprs: FxHashMap::default(), + }; + seed_param_bindings_from_call( + pkg, + store, + caller_locals, + &mut state, + body_input, + args_expr_id, + package_id, + ); + analyze_block_flow(pkg, store, body_block_id, &mut state, package_id); + + let block = pkg.get_block(body_block_id); + let Some(&stmt_id) = block.stmts.last() else { + return CalleeLattice::Dynamic; + }; + let stmt = pkg.get_stmt(stmt_id); + let return_expr_id = match &stmt.kind { + StmtKind::Expr(return_expr_id) => *return_expr_id, + StmtKind::Semi(expr_id) if matches!(pkg.get_expr(*expr_id).kind, ExprKind::Return(_)) => { + let ExprKind::Return(inner_expr_id) = pkg.get_expr(*expr_id).kind else { + unreachable!("guarded above") + }; + inner_expr_id + } + _ => return CalleeLattice::Dynamic, + }; + + materialize_capture_exprs_from_state( + pkg, + &state, + resolve_callee( + pkg, + store, + &state, + return_expr_id, + depth + 1, + allow_scoped_capture_exprs, + scoped_capture_vars, + package_id, + ), + ) +} + +/// Materializes `CapturedVar::expr` fields 
for each capture appearing in a
/// `CalleeLattice` by looking up the capture's defining expression in the
/// current `LocalState` so rewrite can re-emit the captures as arguments.
fn materialize_capture_exprs_from_state(
    pkg: &Package,
    state: &LocalState,
    resolved: CalleeLattice,
) -> CalleeLattice {
    match resolved {
        CalleeLattice::Single(concrete) => {
            CalleeLattice::Single(materialize_capture_exprs_in_callable(pkg, state, concrete))
        }
        CalleeLattice::Multi(entries) => CalleeLattice::Multi(
            entries
                .into_iter()
                .map(|(concrete, condition)| {
                    (
                        // Conditions are preserved untouched; only the
                        // callable's capture exprs are filled in.
                        materialize_capture_exprs_in_callable(pkg, state, concrete),
                        condition,
                    )
                })
                .collect(),
        ),
        // Bottom/Dynamic carry no captures to materialize.
        other => other,
    }
}

/// Fills in missing `CapturedVar::expr` entries for a single resolved
/// callable: for closures, each capture that has no recorded defining
/// expression is looked up in the current `LocalState` via
/// [`resolve_capture_expr_from_state`]. Non-closure callables (globals,
/// `Dynamic`) pass through unchanged.
fn materialize_capture_exprs_in_callable(
    pkg: &Package,
    state: &LocalState,
    concrete: ConcreteCallable,
) -> ConcreteCallable {
    match concrete {
        ConcreteCallable::Closure {
            target,
            mut captures,
            functor,
        } => {
            for capture in &mut captures {
                // Only fill in captures that are still unresolved; earlier
                // resolution passes may already have recorded an expr.
                if capture.expr.is_none() {
                    capture.expr = resolve_capture_expr_from_state(pkg, state, capture.var);
                }
            }

            ConcreteCallable::Closure {
                target,
                captures,
                functor,
            }
        }
        other => other,
    }
}

/// Resolves the defining expression for a captured local by consulting the
/// flow-sensitive `LocalState::exprs` map populated during analysis.
fn resolve_capture_expr_from_state(
    pkg: &Package,
    state: &LocalState,
    var: LocalVarId,
) -> Option<ExprId> {
    let mut current = var;

    // Follow `let a = b;` alias chains, bounded by MAX_RESOLVE_DEPTH so a
    // pathological (or cyclic) chain cannot loop forever.
    for _ in 0..MAX_RESOLVE_DEPTH {
        let &expr_id = state.exprs.get(&current)?;
        let expr = pkg.get_expr(expr_id);
        if let ExprKind::Var(Res::Local(next_var), _) = &expr.kind
            && *next_var != current
            && state.exprs.contains_key(next_var)
        {
            current = *next_var;
            continue;
        }

        // Reached a non-alias initializer: this is the defining expression.
        return Some(expr_id);
    }

    // Depth exhausted — conservatively report "unknown".
    None
}

/// Seeds the callable-flow lattice for a HOF with the concrete callables
/// bound to its arrow parameters at a specific call site, enabling
/// reaching-def analysis to track parameter-forwarding chains.
fn seed_param_bindings_from_call(
    pkg: &Package,
    store: &PackageStore,
    caller_locals: &LocalState,
    state: &mut LocalState,
    pat_id: PatId,
    arg_expr_id: ExprId,
    package_id: PackageId,
) {
    let pat = pkg.get_pat(pat_id);
    match &pat.kind {
        PatKind::Bind(ident) => {
            // Every bound parameter gets an ExprId binding; only arrow-typed
            // parameters additionally get a resolved callable lattice entry.
            state.exprs.insert(ident.id, arg_expr_id);
            if matches!(pat.ty, Ty::Arrow(_)) {
                let lattice = resolve_callee(
                    pkg,
                    store,
                    caller_locals,
                    arg_expr_id,
                    0,
                    true,
                    &FxHashSet::default(),
                    package_id,
                );
                state.callable.insert(ident.id, lattice);
            }
        }
        PatKind::Tuple(sub_pats) => {
            // Only a literal tuple argument with matching arity can be
            // destructured element-wise; anything else is left unbound.
            let arg_expr = pkg.get_expr(arg_expr_id);
            if let ExprKind::Tuple(arg_elems) = &arg_expr.kind
                && sub_pats.len() == arg_elems.len()
            {
                for (&sub_pat_id, &arg_elem_id) in sub_pats.iter().zip(arg_elems.iter()) {
                    seed_param_bindings_from_call(
                        pkg,
                        store,
                        caller_locals,
                        state,
                        sub_pat_id,
                        arg_elem_id,
                        package_id,
                    );
                }
            }
        }
        PatKind::Discard => {}
    }
}

/// Applies an outer functor application to a resolved callable.
+fn apply_outer_functor_cc(resolved: ConcreteCallable, outer: FunctorApp) -> ConcreteCallable { + match resolved { + ConcreteCallable::Global { item_id, functor } => ConcreteCallable::Global { + item_id, + functor: compose_functors(&outer, &functor), + }, + ConcreteCallable::Closure { + target, + captures, + functor, + } => ConcreteCallable::Closure { + target, + captures, + functor: compose_functors(&outer, &functor), + }, + ConcreteCallable::Dynamic => ConcreteCallable::Dynamic, + } +} + +/// Applies an outer functor application to all entries in a lattice element. +fn apply_outer_functor_lattice(resolved: CalleeLattice, outer: FunctorApp) -> CalleeLattice { + if outer == FunctorApp::default() { + return resolved; + } + match resolved { + CalleeLattice::Single(cc) => CalleeLattice::Single(apply_outer_functor_cc(cc, outer)), + CalleeLattice::Multi(entries) => CalleeLattice::Multi( + entries + .into_iter() + .map(|(cc, cond)| (apply_outer_functor_cc(cc, outer), cond)) + .collect(), + ), + other => other, + } +} + +/// Resolves a field access expression to the initialiser `ExprId` of that +/// field within a struct construction. Traces through immutable locals and +/// nested field accesses to locate the struct construction site. 
+fn resolve_struct_field( + pkg: &Package, + locals: &LocalState, + inner_expr_id: ExprId, + path: &FieldPath, + depth: usize, +) -> Option { + if depth > MAX_RESOLVE_DEPTH { + return None; + } + let inner_expr = pkg.get_expr(inner_expr_id); + match &inner_expr.kind { + ExprKind::Tuple(elements) => { + let (&field_index, rest) = path.indices.split_first()?; + let &field_expr_id = elements.get(field_index)?; + if rest.is_empty() { + Some(field_expr_id) + } else { + resolve_struct_field( + pkg, + locals, + field_expr_id, + &FieldPath { + indices: rest.to_vec(), + }, + depth + 1, + ) + } + } + ExprKind::Struct(_, _, fields) => extract_field_value(fields, path), + ExprKind::Call(_, args_id) => resolve_struct_field(pkg, locals, *args_id, path, depth + 1), + ExprKind::Var(Res::Local(var), _) => { + let &init_id = locals.exprs.get(var)?; + resolve_struct_field(pkg, locals, init_id, path, depth + 1) + } + ExprKind::Field(nested_inner_id, Field::Path(nested_path)) => { + // Two-level field access: resolve the outer field to get the inner + // struct expression, then resolve the target field within that. + let intermediate_id = + resolve_struct_field(pkg, locals, *nested_inner_id, nested_path, depth + 1)?; + resolve_struct_field(pkg, locals, intermediate_id, path, depth + 1) + } + _ => None, + } +} + +/// Resolves a single `Index(array, index)` expression to the concrete +/// callable at the indexed position when both the array and index are +/// statically known. +fn resolve_indexed_array_element( + pkg: &Package, + locals: &LocalState, + array_expr_id: ExprId, + index_expr_id: ExprId, + depth: usize, +) -> Option { + if depth > MAX_RESOLVE_DEPTH { + return None; + } + + let index = usize::try_from(resolve_static_int_expr( + pkg, + locals, + index_expr_id, + depth + 1, + )?) 
+ .ok()?; + resolve_array_element_at_index(pkg, locals, array_expr_id, index, depth + 1) +} + +/// Resolves an `Index(array, index)` where the array is known but the +/// index may vary, returning a `CalleeLattice` of all statically possible +/// callables keyed against each index value. +#[allow(clippy::too_many_arguments)] +fn resolve_indexed_callable_candidates( + pkg: &Package, + store: &PackageStore, + locals: &LocalState, + array_expr_id: ExprId, + depth: usize, + allow_scoped_capture_exprs: bool, + scoped_capture_vars: &FxHashSet, + package_id: PackageId, +) -> Option> { + let element_expr_ids = resolve_array_elements(pkg, locals, array_expr_id, depth + 1)?; + let mut candidates = Vec::new(); + + for elem_expr_id in element_expr_ids { + let resolved = resolve_callee( + pkg, + store, + locals, + elem_expr_id, + depth + 1, + allow_scoped_capture_exprs, + scoped_capture_vars, + package_id, + ); + + match resolved { + CalleeLattice::Single(callable) => { + if !candidates.contains(&callable) { + candidates.push(callable); + } + } + CalleeLattice::Multi(entries) => { + for (callable, condition) in entries { + if condition.is_some() { + return None; + } + if !candidates.contains(&callable) { + candidates.push(callable); + } + } + } + CalleeLattice::Bottom | CalleeLattice::Dynamic => return None, + } + + if candidates.len() > super::types::MULTI_CAP { + return None; + } + } + + (!candidates.is_empty()).then_some(candidates) +} + +/// Resolves an array-literal expression to the concrete callables stored in +/// each element slot, yielding `None` when any element is not statically +/// known. 
fn resolve_array_elements(
    pkg: &Package,
    locals: &LocalState,
    expr_id: ExprId,
    depth: usize,
) -> Option<Vec<ExprId>> {
    // Recursion fuse against cyclic/deep initializer chains.
    if depth > MAX_RESOLVE_DEPTH {
        return None;
    }

    let expr = pkg.get_expr(expr_id);
    match &expr.kind {
        // Literal array/tuple forms expose their element ExprIds directly.
        ExprKind::Array(elements) | ExprKind::ArrayLit(elements) | ExprKind::Tuple(elements) => {
            Some(elements.clone())
        }
        // Immutable local: chase its recorded initializer expression.
        ExprKind::Var(Res::Local(var), _) => locals
            .exprs
            .get(var)
            .and_then(|&init_expr_id| resolve_array_elements(pkg, locals, init_expr_id, depth + 1)),
        ExprKind::Block(block_id) => {
            // A block's value is its final statement's expression (if any).
            let block = pkg.get_block(*block_id);
            let stmt_id = *block.stmts.last()?;
            let stmt = pkg.get_stmt(stmt_id);
            let tail_expr_id = match &stmt.kind {
                StmtKind::Expr(expr_id) | StmtKind::Semi(expr_id) => *expr_id,
                _ => return None,
            };
            resolve_array_elements(pkg, locals, tail_expr_id, depth + 1)
        }
        ExprKind::Return(inner_expr_id) => {
            resolve_array_elements(pkg, locals, *inner_expr_id, depth + 1)
        }
        // Field access: locate the field's initializer within the struct
        // construction, then resolve that as the array.
        ExprKind::Field(inner_expr_id, Field::Path(path)) => {
            let field_value_id =
                resolve_struct_field(pkg, locals, *inner_expr_id, path, depth + 1)?;
            resolve_array_elements(pkg, locals, field_value_id, depth + 1)
        }
        _ => None,
    }
}

/// Resolves the element at a specific static index within an array-literal
/// expression (after [`resolve_array_elements`] has resolved each slot).
fn resolve_array_element_at_index(
    pkg: &Package,
    locals: &LocalState,
    expr_id: ExprId,
    index: usize,
    depth: usize,
) -> Option<ExprId> {
    // Recursion fuse against cyclic/deep initializer chains.
    if depth > MAX_RESOLVE_DEPTH {
        return None;
    }

    let expr = pkg.get_expr(expr_id);
    match &expr.kind {
        // Literal array/tuple: checked indexing — out-of-range yields None.
        ExprKind::Array(elements) | ExprKind::ArrayLit(elements) | ExprKind::Tuple(elements) => {
            elements.get(index).copied()
        }
        // Immutable local: chase its recorded initializer expression.
        ExprKind::Var(Res::Local(var), _) => locals.exprs.get(var).and_then(|&init_expr_id| {
            resolve_array_element_at_index(pkg, locals, init_expr_id, index, depth + 1)
        }),
        ExprKind::Block(block_id) => {
            // A block's value is its final statement's expression (if any).
            let block = pkg.get_block(*block_id);
            let stmt_id = *block.stmts.last()?;
            let stmt = pkg.get_stmt(stmt_id);
            let tail_expr_id = match &stmt.kind {
                StmtKind::Expr(expr_id) | StmtKind::Semi(expr_id) => *expr_id,
                _ => return None,
            };
            resolve_array_element_at_index(pkg, locals, tail_expr_id, index, depth + 1)
        }
        ExprKind::Return(inner_expr_id) => {
            resolve_array_element_at_index(pkg, locals, *inner_expr_id, index, depth + 1)
        }
        // Field access: locate the field's initializer within the struct
        // construction, then index into that value.
        ExprKind::Field(inner_expr_id, Field::Path(path)) => {
            let field_value_id =
                resolve_struct_field(pkg, locals, *inner_expr_id, path, depth + 1)?;
            resolve_array_element_at_index(pkg, locals, field_value_id, index, depth + 1)
        }
        _ => None,
    }
}

/// Attempts to reduce an expression to a compile-time integer value so that
/// indexed lookups can locate their source element statically.
fn resolve_static_int_expr(
    pkg: &Package,
    locals: &LocalState,
    expr_id: ExprId,
    depth: usize,
) -> Option<i64> {
    // Recursion fuse against cyclic/deep initializer chains.
    if depth > MAX_RESOLVE_DEPTH {
        return None;
    }

    let expr = pkg.get_expr(expr_id);
    match &expr.kind {
        ExprKind::Lit(Lit::Int(value)) => Some(*value),
        // Immutable local: chase its recorded initializer expression.
        ExprKind::Var(Res::Local(var), _) => locals.exprs.get(var).and_then(|&init_expr_id| {
            resolve_static_int_expr(pkg, locals, init_expr_id, depth + 1)
        }),
        ExprKind::Block(block_id) => {
            // A block's value is its final statement's expression (if any).
            let block = pkg.get_block(*block_id);
            let stmt_id = *block.stmts.last()?;
            let stmt = pkg.get_stmt(stmt_id);
            let tail_expr_id = match &stmt.kind {
                StmtKind::Expr(expr_id) | StmtKind::Semi(expr_id) => *expr_id,
                _ => return None,
            };
            resolve_static_int_expr(pkg, locals, tail_expr_id, depth + 1)
        }
        ExprKind::Return(inner_expr_id) => {
            resolve_static_int_expr(pkg, locals, *inner_expr_id, depth + 1)
        }
        // Unary negation of a static int is itself static.
        ExprKind::UnOp(UnOp::Neg, inner_expr_id) => {
            resolve_static_int_expr(pkg, locals, *inner_expr_id, depth + 1).map(std::ops::Neg::neg)
        }
        _ => None,
    }
}

/// Extracts the value `ExprId` for a field from a struct construction's field
/// assignments by matching on the first index of the access path.
fn extract_field_value(fields: &[FieldAssign], path: &FieldPath) -> Option<ExprId> {
    let target_index = path.indices.first()?;
    for fa in fields {
        // Match only on the leading path index; deeper indices are resolved
        // by the caller recursing into the returned value expression.
        if let Field::Path(fa_path) = &fa.field
            && fa_path.indices.first() == Some(target_index)
        {
            return Some(fa.value);
        }
    }
    None
}

/// Resolves each captured variable of a closure to a `CapturedVar`, pairing
/// its type with (when statically known) its defining expression.
pub(super) fn resolve_captures(
    pkg: &Package,
    locals: &LocalState,
    captured_vars: &[LocalVarId],
    scoped_capture_vars: &FxHashSet<LocalVarId>,
) -> Option<Vec<CapturedVar>> {
    // Collecting Options short-circuits: any capture whose type cannot be
    // determined makes the whole closure unresolvable.
    captured_vars
        .iter()
        .map(|&var| {
            let ty = find_local_var_type(pkg, locals, var)?;
            let expr = resolve_scoped_capture_expr(pkg, locals, var, scoped_capture_vars);
            Some(CapturedVar { var, ty, expr })
        })
        .collect()
}

/// Resolves a capture expression by walking the enclosing block scope and
/// its visible local bindings, used when the straightforward
/// [`resolve_capture_expr_from_state`] lookup cannot see the binding.
fn resolve_scoped_capture_expr(
    pkg: &Package,
    locals: &LocalState,
    var: LocalVarId,
    scoped_capture_vars: &FxHashSet<LocalVarId>,
) -> Option<ExprId> {
    // Only variables explicitly marked as scope-visible are resolvable here.
    if !scoped_capture_vars.contains(&var) {
        return None;
    }

    let mut current = var;
    // Follow alias chains, bounded by MAX_RESOLVE_DEPTH so cyclic chains
    // cannot loop forever; aliases must also be scope-visible to be followed.
    for _ in 0..MAX_RESOLVE_DEPTH {
        let &expr_id = locals.exprs.get(&current)?;
        let expr = pkg.get_expr(expr_id);
        if let ExprKind::Var(Res::Local(next_var), _) = &expr.kind
            && *next_var != current
            && scoped_capture_vars.contains(next_var)
        {
            current = *next_var;
            continue;
        }

        return Some(expr_id);
    }

    None
}

/// Collects all local variables bound within a block (recursively through
/// statements and nested blocks) into `bound`, used to scope capture
/// resolution.
fn collect_block_local_bindings(
    pkg: &Package,
    block_id: BlockId,
    bound: &mut FxHashSet<LocalVarId>,
) {
    let block = pkg.get_block(block_id);
    for stmt_id in &block.stmts {
        let stmt = pkg.get_stmt(*stmt_id);
        // Only `let` statements introduce bindings at this level; nested
        // blocks are handled when their own expressions are resolved.
        if let StmtKind::Local(_, pat_id, _) = stmt.kind {
            collect_pat_local_bindings(pkg, pat_id, bound);
        }
    }
}

/// Collects every local-variable binding introduced by a pattern into
/// `bound`, recursing into tuple patterns.
fn collect_pat_local_bindings(pkg: &Package, pat_id: PatId, bound: &mut FxHashSet<LocalVarId>) {
    let pat = pkg.get_pat(pat_id);
    match &pat.kind {
        PatKind::Bind(ident) => {
            bound.insert(ident.id);
        }
        PatKind::Discard => {}
        PatKind::Tuple(pats) => {
            // Recurse so nested tuple patterns contribute all their bindings.
            for &sub_pat_id in pats {
                collect_pat_local_bindings(pkg, sub_pat_id, bound);
            }
        }
    }
}

/// Finds the type of a local variable by looking up its initialiser expression.
/// Falls back to a full pattern scan when the variable is not in the
/// immutable-locals map (e.g. function parameters or outer-scope bindings).
fn find_local_var_type(pkg: &Package, locals: &LocalState, var: LocalVarId) -> Option<Ty> {
    if let Some(&init_expr_id) = locals.exprs.get(&var) {
        // The initializer's type is the binding's type.
        Some(pkg.get_expr(init_expr_id).ty.clone())
    } else {
        // The variable may be a function parameter or from an outer scope not
        // tracked in the immutable-locals map. Scan all patterns as a fallback.
        find_var_type_in_pats(pkg, var)
    }
}

/// Scans all patterns in a package to find the type of a given `LocalVarId`.
///
/// Returns `None` if no binding pattern is found. Valid FIR gives every
/// `LocalVarId` a corresponding binding pattern, but returning `None` lets
/// callers degrade analysis for malformed or partially transformed input
/// instead of panicking.
fn find_var_type_in_pats(pkg: &Package, var: LocalVarId) -> Option<Ty> {
    // O(#pats) linear scan — acceptable as a rare fallback path.
    for pat in pkg.pats.values() {
        if let PatKind::Bind(ident) = &pat.kind
            && ident.id == var
        {
            return Some(pat.ty.clone());
        }
    }
    None
}

/// Builds flow-sensitive local variable state by performing a single forward
/// pass over the callable's body.
///
/// For callable-typed locals, the analysis tracks reaching definitions through
/// `set` assignments, forks state at `if`/`else` branches, and conservatively
/// marks mutable callable vars assigned inside `while` loops as `Dynamic`.
+/// +/// For all immutable locals, the raw `ExprId` binding is also recorded for +/// struct field resolution and type look-ups. +fn build_callable_flow_state( + pkg: &Package, + store: &PackageStore, + callable_impl: &CallableImpl, + package_id: PackageId, +) -> LocalState { + let mut state = LocalState { + callable: FxHashMap::default(), + exprs: FxHashMap::default(), + }; + match callable_impl { + CallableImpl::Intrinsic => {} + CallableImpl::Spec(spec_impl) => { + analyze_spec_flow(pkg, store, spec_impl, &mut state, package_id); + } + CallableImpl::SimulatableIntrinsic(spec_decl) => { + analyze_block_flow(pkg, store, spec_decl.block, &mut state, package_id); + } + } + state +} + +/// Runs callable-flow analysis over a single `SpecImpl`, merging the +/// resulting per-variable lattice with the caller-provided accumulator. +fn analyze_spec_flow( + pkg: &Package, + store: &PackageStore, + spec_impl: &SpecImpl, + state: &mut LocalState, + package_id: PackageId, +) { + analyze_block_flow(pkg, store, spec_impl.body.block, state, package_id); + for spec in functored_specs(spec_impl) { + analyze_block_flow(pkg, store, spec.block, state, package_id); + } +} + +/// Walks a block's statements, propagating callable-flow lattice updates +/// top-down so conditional joins preserve per-branch condition tags. +fn analyze_block_flow( + pkg: &Package, + store: &PackageStore, + block_id: BlockId, + state: &mut LocalState, + package_id: PackageId, +) { + let block = pkg.get_block(block_id); + for &stmt_id in &block.stmts { + let stmt = pkg.get_stmt(stmt_id); + analyze_stmt_flow(pkg, store, &stmt.kind, state, package_id); + } +} + +/// Updates the callable-flow lattice for a single statement (local +/// bindings, assignments, and expression statements) before recursing into +/// nested blocks. 
+fn analyze_stmt_flow( + pkg: &Package, + store: &PackageStore, + kind: &StmtKind, + state: &mut LocalState, + package_id: PackageId, +) { + match kind { + StmtKind::Local(Mutability::Immutable, pat_id, init_expr_id) => { + // Record ExprId bindings for all immutable locals. + collect_bindings_from_pat(pkg, *pat_id, *init_expr_id, &mut state.exprs); + // For callable-typed bindings, resolve and store in lattice. + bind_callable_pat(pkg, store, state, *pat_id, *init_expr_id, package_id); + analyze_expr_flow(pkg, store, *init_expr_id, state, package_id); + } + StmtKind::Local(Mutability::Mutable, pat_id, init_expr_id) => { + bind_callable_pat(pkg, store, state, *pat_id, *init_expr_id, package_id); + analyze_expr_flow(pkg, store, *init_expr_id, state, package_id); + } + StmtKind::Expr(e) | StmtKind::Semi(e) => { + analyze_expr_flow(pkg, store, *e, state, package_id); + } + StmtKind::Item(_) => {} + } +} + +/// Binds callable-typed variables from a pattern to their resolved +/// `CalleeLattice` values. +fn bind_callable_pat( + pkg: &Package, + store: &PackageStore, + state: &mut LocalState, + pat_id: qsc_fir::fir::PatId, + init_expr_id: ExprId, + package_id: PackageId, +) { + let pat = pkg.get_pat(pat_id); + match &pat.kind { + PatKind::Bind(ident) => { + if matches!(pat.ty, Ty::Arrow(_)) { + let lattice = resolve_callee( + pkg, + store, + state, + init_expr_id, + 0, + true, + &FxHashSet::default(), + package_id, + ); + state.callable.insert(ident.id, lattice); + } + } + PatKind::Tuple(sub_pats) => { + let init_expr = pkg.get_expr(init_expr_id); + if let ExprKind::Tuple(init_elems) = &init_expr.kind + && sub_pats.len() == init_elems.len() + { + for (&sub_pat_id, &elem_expr_id) in sub_pats.iter().zip(init_elems.iter()) { + bind_callable_pat(pkg, store, state, sub_pat_id, elem_expr_id, package_id); + } + } else { + // Non-tuple init (e.g., ExprKind::Index from for-loop desugaring). + // Resolve the init through variable indirection first. 
+ let resolved_init_id = resolve_through_vars(pkg, state, init_expr_id); + let resolved_init = pkg.get_expr(resolved_init_id); + + if let ExprKind::Tuple(init_elems) = &resolved_init.kind + && sub_pats.len() == init_elems.len() + { + // Resolved to a literal tuple — recurse element-wise. + for (&sub_pat_id, &elem_expr_id) in sub_pats.iter().zip(init_elems.iter()) { + bind_callable_pat(pkg, store, state, sub_pat_id, elem_expr_id, package_id); + } + } else if let ExprKind::Index(array_expr_id, _) = &resolved_init.kind { + // Dynamic array index: resolve all array elements and extract + // per-field callables for each arrow-typed sub-pattern. + bind_callable_pats_from_indexed_array( + pkg, + store, + state, + sub_pats, + *array_expr_id, + package_id, + ); + } + } + } + PatKind::Discard => {} + } +} + +/// Follows `ExprKind::Var(Res::Local(var))` through `state.exprs` to find +/// the underlying expression, stopping when no further indirection exists. +fn resolve_through_vars(pkg: &Package, state: &LocalState, expr_id: ExprId) -> ExprId { + let expr = pkg.get_expr(expr_id); + if let ExprKind::Var(Res::Local(var), _) = &expr.kind + && let Some(&init_id) = state.exprs.get(var) + { + return resolve_through_vars(pkg, state, init_id); + } + expr_id +} + +/// Binds callable-typed sub-patterns from a tuple pattern where the init +/// expression is `array[dynamic_index]`. Resolves all array elements, +/// extracts the field at each sub-pattern position, and joins the resolved +/// callables into a `CalleeLattice`. +fn bind_callable_pats_from_indexed_array( + pkg: &Package, + store: &PackageStore, + state: &mut LocalState, + sub_pats: &[PatId], + array_expr_id: ExprId, + package_id: PackageId, +) { + // Resolve the array to its element ExprIds. + let Some(array_elem_ids) = resolve_array_elements(pkg, state, array_expr_id, 0) else { + return; // Cannot resolve array — leave sub-patterns unbound (conservative). 
+ }; + + for (field_idx, &sub_pat_id) in sub_pats.iter().enumerate() { + let sub_pat = pkg.get_pat(sub_pat_id); + let PatKind::Bind(ident) = &sub_pat.kind else { + continue; // Skip Discard and nested Tuple for now. + }; + if !matches!(sub_pat.ty, Ty::Arrow(_)) { + continue; // Only bind arrow-typed locals. + } + + // Collect the callable at field_idx from each array element tuple. + let mut lattice = CalleeLattice::Bottom; + for &elem_expr_id in &array_elem_ids { + let elem_expr = pkg.get_expr(elem_expr_id); + if let ExprKind::Tuple(fields) = &elem_expr.kind + && let Some(&field_expr_id) = fields.get(field_idx) + { + let field_lattice = resolve_callee( + pkg, + store, + state, + field_expr_id, + 0, + true, + &FxHashSet::default(), + package_id, + ); + lattice = lattice.join(field_lattice); + } + } + + if !matches!(lattice, CalleeLattice::Bottom) { + state.callable.insert(ident.id, lattice); + } + } +} + +/// Walks an expression for control-flow structures that affect reaching +/// definitions: assignments, blocks, conditionals, and loops. +fn analyze_expr_flow( + pkg: &Package, + store: &PackageStore, + expr_id: ExprId, + state: &mut LocalState, + package_id: PackageId, +) { + let expr = pkg.get_expr(expr_id); + match &expr.kind { + ExprKind::Assign(lhs_id, rhs_id) => { + let lhs = pkg.get_expr(*lhs_id); + if let ExprKind::Var(Res::Local(var), _) = &lhs.kind + && state.callable.contains_key(var) + { + let lattice = resolve_callee( + pkg, + store, + state, + *rhs_id, + 0, + true, + &FxHashSet::default(), + package_id, + ); + state.callable.insert(*var, lattice); + } + } + ExprKind::Block(block_id) => { + analyze_block_flow(pkg, store, *block_id, state, package_id); + } + ExprKind::If(cond, body, otherwise) => { + analyze_expr_flow(pkg, store, *cond, state, package_id); + // Fork: save callable state before branches. + let pre_if = state.callable.clone(); + // Analyze true branch. 
+ analyze_expr_flow(pkg, store, *body, state, package_id); + let true_state = state.callable.clone(); + // Restore pre-if state and analyze false branch. + state.callable = pre_if; + if let Some(else_expr) = otherwise { + analyze_expr_flow(pkg, store, *else_expr, state, package_id); + } + // Join: merge true and false branch states per variable, + // tagging entries with the condition for branch splitting. + let false_state = std::mem::take(&mut state.callable); + state.callable = join_callable_states_with_condition(&true_state, &false_state, *cond); + } + ExprKind::While(cond, block_id) => { + analyze_expr_flow(pkg, store, *cond, state, package_id); + // Conservative: mark all mutable callable vars assigned inside + // the loop body as Dynamic. + let assigned = collect_assigned_vars_in_block(pkg, *block_id); + for var in &assigned { + if state.callable.contains_key(var) { + state.callable.insert(*var, CalleeLattice::Dynamic); + } + } + // Analyze the body for nested let bindings. Restore pre-existing + // callable entries to their pre-loop values, but keep NEW entries + // added by loop-body analysis (loop-local immutable bindings). + let pre_loop_callable = state.callable.clone(); + analyze_block_flow(pkg, store, *block_id, state, package_id); + for (var, lattice) in pre_loop_callable { + state.callable.insert(var, lattice); + } + } + _ => {} + } +} + +/// Joins two callable-state maps by performing per-variable lattice join +/// with an associated condition from an if/else branch. 
+fn join_callable_states_with_condition(
+ true_state: &FxHashMap<LocalVarId, CalleeLattice>,
+ false_state: &FxHashMap<LocalVarId, CalleeLattice>,
+ condition: ExprId,
+) -> FxHashMap<LocalVarId, CalleeLattice> {
+ let mut result = FxHashMap::default();
+ let all_vars: FxHashSet<LocalVarId> = true_state
+ .keys()
+ .chain(false_state.keys())
+ .copied()
+ .collect();
+ for var in all_vars {
+ let a_val = true_state
+ .get(&var)
+ .cloned()
+ .unwrap_or(CalleeLattice::Bottom);
+ let b_val = false_state
+ .get(&var)
+ .cloned()
+ .unwrap_or(CalleeLattice::Bottom);
+ result.insert(var, a_val.join_with_condition(b_val, condition));
+ }
+ result
+}
+
+/// Collects all `LocalVarId`s that are targets of `Assign` expressions
+/// within a block (recursively including nested blocks and control flow).
+fn collect_assigned_vars_in_block(pkg: &Package, block_id: BlockId) -> Vec<LocalVarId> {
+ let mut vars = Vec::new();
+ collect_assigned_vars_block(pkg, block_id, &mut vars);
+ vars
+}
+
+/// Collects every `LocalVarId` assigned within a block (mutable update or
+/// `Assign`), accumulating into `vars` so branch joins can invalidate
+/// stale lattice entries.
+fn collect_assigned_vars_block(pkg: &Package, block_id: BlockId, vars: &mut Vec<LocalVarId>) {
+ let block = pkg.get_block(block_id);
+ for &stmt_id in &block.stmts {
+ let stmt = pkg.get_stmt(stmt_id);
+ match &stmt.kind {
+ StmtKind::Expr(e) | StmtKind::Semi(e) | StmtKind::Local(_, _, e) => {
+ collect_assigned_vars_expr(pkg, *e, vars);
+ }
+ StmtKind::Item(_) => {}
+ }
+ }
+}
+
+/// Collects every `LocalVarId` assigned within an expression subtree,
+/// recursing through nested blocks, conditionals, and loops. 
+fn collect_assigned_vars_expr(pkg: &Package, expr_id: ExprId, vars: &mut Vec<LocalVarId>) {
+ let expr = pkg.get_expr(expr_id);
+ match &expr.kind {
+ ExprKind::Assign(lhs_id, _) => {
+ let lhs = pkg.get_expr(*lhs_id);
+ if let ExprKind::Var(Res::Local(var), _) = &lhs.kind {
+ vars.push(*var);
+ }
+ }
+ ExprKind::Block(block_id) | ExprKind::While(_, block_id) => {
+ collect_assigned_vars_block(pkg, *block_id, vars);
+ }
+ ExprKind::If(_, body, otherwise) => {
+ collect_assigned_vars_expr(pkg, *body, vars);
+ if let Some(e) = otherwise {
+ collect_assigned_vars_expr(pkg, *e, vars);
+ }
+ }
+ _ => {}
+ }
+}
+
+/// Extracts bindings from a pattern. For `Bind(ident)` patterns, records
+/// `ident.id → init_expr_id`. For `Tuple` patterns, we cannot easily
+/// split the init expression, so we skip those.
+fn collect_bindings_from_pat(
+ pkg: &Package,
+ pat_id: qsc_fir::fir::PatId,
+ init_expr_id: ExprId,
+ map: &mut FxHashMap<LocalVarId, ExprId>,
+) {
+ let pat = pkg.get_pat(pat_id);
+ match &pat.kind {
+ PatKind::Bind(ident) => {
+ map.insert(ident.id, init_expr_id);
+ }
+ PatKind::Tuple(sub_pats) => {
+ // If the init is also a tuple expression, match element-wise.
+ let init_expr = pkg.get_expr(init_expr_id);
+ if let ExprKind::Tuple(init_elems) = &init_expr.kind
+ && sub_pats.len() == init_elems.len()
+ {
+ for (&sub_pat_id, &elem_expr_id) in sub_pats.iter().zip(init_elems.iter()) {
+ collect_bindings_from_pat(pkg, sub_pat_id, elem_expr_id, map);
+ }
+ }
+ }
+ PatKind::Discard => {}
+ }
+}
diff --git a/source/compiler/qsc_fir_transforms/src/defunctionalize/prepass.rs b/source/compiler/qsc_fir_transforms/src/defunctionalize/prepass.rs
new file mode 100644
index 0000000000..f364b574e5
--- /dev/null
+++ b/source/compiler/qsc_fir_transforms/src/defunctionalize/prepass.rs
@@ -0,0 +1,398 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+//! Pre-pass rewrites before collecting call sites for defunctionalization.
+//! 
These rewrites are not strictly necessary for correctness, but they +//! simplify the analysis by eliminating certain patterns of indirection and +//! exposing more direct call sites. They are run before collecting call sites +//! and performing the lattice analysis. +//! +//! # Responsibilities +//! +//! - Run the single-use local promotion that replaces single-use immutable +//! callable locals with direct references to their initializer (via +//! [`promote_single_use_callable_locals`]). +//! - Run the identity-closure peephole that replaces `(args) => f(args)` +//! closures with direct references to `f` (via +//! [`identity_closure_peephole`]). +//! + +use qsc_fir::fir::{ + CallableImpl, ExprId, ExprKind, ItemKind, LocalVarId, Mutability, Package, PackageId, + PackageLookup, PackageStore, PatKind, Res, StmtKind, UnOp, +}; +use qsc_fir::ty::Ty; +use rustc_hash::FxHashMap; + +/// Runs pre-pass rewrites before collecting call sites for defunctionalization. See +/// [`promote_single_use_callable_locals`] and [`identity_closure_peephole`] for details. +/// +/// Only expressions in `reachable_expr_ids` are scanned for promotion candidates +/// and identity-closure patterns, restricting analysis to entry-reachable code. +pub(super) fn run(store: &mut PackageStore, package_id: PackageId, reachable_expr_ids: &[ExprId]) { + // Before collecting call sites, runs pre-pass rewrites: + // 1. Promotes single-use immutable callable locals to direct item references. + // 2. Replaces identity closures `(args) => f(args)` with direct references to `f`. + promote_single_use_callable_locals(store, package_id, reachable_expr_ids); + identity_closure_peephole(store, package_id, reachable_expr_ids); +} + +/// Promotes single-use immutable callable locals whose initializer is a simple +/// item reference. For example, `let op = H; Apply(op, q)` is rewritten to +/// `Apply(H, q)`, eliminating the indirection before analysis runs. 
+///
+/// # Before
+/// ```text
+/// let op = H; // Local(pat, Var(Item(H)))
+/// Apply(op, qubit); // Call(Apply, (Var(Local(op)), qubit))
+/// ```
+/// # After
+/// ```text
+/// let op = H; // binding still present (DCE removes later)
+/// Apply(H, qubit); // Call(Apply, (Var(Item(H)), qubit))
+/// ```
+///
+/// # Mutations
+/// - Rewrites `Expr.kind` at each single-use site from `Var(Local(..))`
+/// to `Var(Item(..))` in place.
+fn promote_single_use_callable_locals(
+ store: &mut PackageStore,
+ package_id: PackageId,
+ reachable_expr_ids: &[ExprId],
+) {
+ let replacements = {
+ let pkg = store.get(package_id);
+ collect_single_use_promotions(pkg, reachable_expr_ids)
+ };
+
+ if !replacements.is_empty() {
+ let pkg = store.get_mut(package_id);
+ for (expr_id, new_kind) in replacements {
+ pkg.exprs
+ .get_mut(expr_id)
+ .expect("expression should exist")
+ .kind = new_kind;
+ }
+ }
+}
+
+/// Scans all immutable local bindings whose initializer is a simple item
+/// reference (`Var(Res::Item(_))`), counts uses within reachable expressions,
+/// and collects replacements for locals that are used exactly once.
+fn collect_single_use_promotions(
+ pkg: &Package,
+ reachable_expr_ids: &[ExprId],
+) -> Vec<(ExprId, ExprKind)> {
+ // find candidate immutable locals whose init is a simple item reference.
+ let mut candidates: FxHashMap<LocalVarId, ExprKind> = FxHashMap::default();
+ for (_, stmt) in &pkg.stmts {
+ if let StmtKind::Local(Mutability::Immutable, pat_id, init_expr_id) = &stmt.kind {
+ let pat = pkg.get_pat(*pat_id);
+ if let PatKind::Bind(ident) = &pat.kind
+ && matches!(pat.ty, Ty::Arrow(_))
+ {
+ let init_expr = pkg.get_expr(*init_expr_id);
+ if let ExprKind::Var(Res::Item(item_id), generic_args) = &init_expr.kind {
+ candidates.insert(
+ ident.id,
+ ExprKind::Var(Res::Item(*item_id), generic_args.clone()),
+ );
+ }
+ }
+ }
+ }
+
+ if candidates.is_empty() {
+ return Vec::new();
+ }
+
+ // exclude candidates that are captured by closures (within reachable code). 
+ for &expr_id in reachable_expr_ids {
+ let expr = pkg.get_expr(expr_id);
+ if let ExprKind::Closure(captures, _) = &expr.kind {
+ for var in captures {
+ candidates.remove(var);
+ }
+ }
+ }
+
+ if candidates.is_empty() {
+ return Vec::new();
+ }
+
+ // count uses and record use-site expression IDs (within reachable code).
+ let mut use_info: FxHashMap<LocalVarId, Vec<ExprId>> =
+ candidates.keys().map(|&var| (var, Vec::new())).collect();
+
+ for &expr_id in reachable_expr_ids {
+ let expr = pkg.get_expr(expr_id);
+ if let ExprKind::Var(Res::Local(var), _) = &expr.kind
+ && let Some(uses) = use_info.get_mut(var)
+ {
+ uses.push(expr_id);
+ }
+ }
+
+ // build replacements for single-use locals.
+ let mut replacements = Vec::new();
+ for (var, uses) in &use_info {
+ if uses.len() == 1 {
+ replacements.push((uses[0], candidates[var].clone()));
+ }
+ }
+
+ replacements
+}
+
+/// Replaces identity closures `(args) => f(args)` with direct references to
+/// the callee in the package's expressions. An identity closure is one whose
+/// body is a single call that forwards all actual parameters in order to a
+/// callee that is either a global item or a single captured variable.
+///
+/// # Before
+/// ```text
+/// Closure([captures], target) // target body: (args) => callee(args)
+/// ```
+/// # After (global callee)
+/// ```text
+/// Var(Item(callee_item)) // closure collapsed to direct item reference
+/// ```
+/// # After (captured-local callee)
+/// ```text
+/// Var(Local(outer_var)) // closure collapsed to outer-scope local
+/// ```
+/// # After (functor-wrapped callee)
+/// ```text
+/// UnOp(Functor(Adj), Var(Item(callee_item))) // functor chain preserved
+/// ```
+///
+/// # Mutations
+/// - Rewrites `Expr.kind` at each identity-closure site in place.
+fn identity_closure_peephole(
+ store: &mut PackageStore,
+ package_id: PackageId,
+ reachable_expr_ids: &[ExprId],
+) {
+ // Collect replacements using an immutable borrow. 
+ let replacements = { + let pkg = store.get(package_id); + collect_identity_closures(pkg, reachable_expr_ids) + }; + + // Apply replacements using a mutable borrow. + if !replacements.is_empty() { + let pkg = store.get_mut(package_id); + for (expr_id, new_kind) in replacements { + pkg.exprs + .get_mut(expr_id) + .expect("expression should exist") + .kind = new_kind; + } + } +} + +/// Scans reachable expressions and collects `(ExprId, replacement ExprKind)` pairs +/// for identity closures. +fn collect_identity_closures( + pkg: &Package, + reachable_expr_ids: &[ExprId], +) -> Vec<(ExprId, ExprKind)> { + let mut replacements = Vec::new(); + + for &expr_id in reachable_expr_ids { + let expr = pkg.get_expr(expr_id); + if let ExprKind::Closure(captures, target) = &expr.kind { + replacements.extend(check_identity_closure(pkg, expr_id, captures, *target)); + } + } + + replacements +} + +/// Checks whether a closure is an identity wrapper `(args) => f(args)` or a +/// functor-wrapped identity `(args) => Adjoint f(args)` / +/// `(args) => Controlled f(args)`, and returns expression replacements that +/// collapse the closure to a direct reference (optionally functor-applied). +fn check_identity_closure( + pkg: &Package, + closure_expr_id: ExprId, + captures: &[LocalVarId], + target: qsc_fir::fir::LocalItemId, +) -> Vec<(ExprId, ExprKind)> { + // Get the closure's callable declaration. + let Some(item) = pkg.items.get(target) else { + return Vec::new(); + }; + let ItemKind::Callable(decl) = &item.kind else { + return Vec::new(); + }; + + // Only handle Spec implementations (not Intrinsic). + let body_block_id = match &decl.implementation { + CallableImpl::Spec(spec_impl) => spec_impl.body.block, + _ => return Vec::new(), + }; + + let block = pkg.get_block(body_block_id); + + // Body must have exactly one statement. 
+ if block.stmts.len() != 1 { + return Vec::new(); + } + + let stmt = pkg.get_stmt(block.stmts[0]); + let call_expr_id = match &stmt.kind { + StmtKind::Semi(e) | StmtKind::Expr(e) => *e, + _ => return Vec::new(), + }; + + let call_expr = pkg.get_expr(call_expr_id); + let (callee_id, args_id) = match &call_expr.kind { + ExprKind::Call(callee, args) => (*callee, *args), + _ => return Vec::new(), + }; + + // Parse the callable's input pattern to separate capture params from actual params. + let Some(all_param_vars) = extract_flat_param_vars(pkg, decl.input) else { + return Vec::new(); + }; + let num_captures = captures.len(); + if all_param_vars.len() < num_captures { + return Vec::new(); + } + let capture_param_vars = &all_param_vars[..num_captures]; + let actual_param_vars = &all_param_vars[num_captures..]; + + // Must have at least one actual parameter to be a meaningful identity wrapper. + if actual_param_vars.is_empty() { + return Vec::new(); + } + + // Verify that args forward all actual params in order. + if !args_forward_params_in_order(pkg, args_id, actual_param_vars) { + return Vec::new(); + } + + // Ensure no capture params appear in the arguments. + if captures_appear_in_args(pkg, args_id, capture_param_vars) { + return Vec::new(); + } + + // Determine the replacement based on the callee expression. + let callee_expr = pkg.get_expr(callee_id); + match &callee_expr.kind { + // Callee is a captured local variable — replace with the enclosing scope's var. + ExprKind::Var(Res::Local(var), _) => { + let Some(capture_idx) = capture_param_vars.iter().position(|&v| v == *var) else { + return Vec::new(); + }; + vec![( + closure_expr_id, + ExprKind::Var(Res::Local(captures[capture_idx]), Vec::new()), + )] + } + // Callee is a global item — replace with the global reference. 
+ ExprKind::Var(Res::Item(item_id), generic_args) => {
+ vec![(
+ closure_expr_id,
+ ExprKind::Var(Res::Item(*item_id), generic_args.clone()),
+ )]
+ }
+ // Callee is a functor-wrapped expression — replace closure with the functor
+ // application and rewrite the inner expression to reference the enclosing scope.
+ ExprKind::UnOp(UnOp::Functor(functor), inner_id) => {
+ let inner_expr = pkg.get_expr(*inner_id);
+ match &inner_expr.kind {
+ ExprKind::Var(Res::Local(var), _) => {
+ let Some(capture_idx) = capture_param_vars.iter().position(|&v| v == *var)
+ else {
+ return Vec::new();
+ };
+ vec![
+ (
+ *inner_id,
+ ExprKind::Var(Res::Local(captures[capture_idx]), Vec::new()),
+ ),
+ (
+ closure_expr_id,
+ ExprKind::UnOp(UnOp::Functor(*functor), *inner_id),
+ ),
+ ]
+ }
+ ExprKind::Var(Res::Item(_), _) => {
+ // Inner expression already references the global item; only
+ // the closure expression needs replacing.
+ vec![(
+ closure_expr_id,
+ ExprKind::UnOp(UnOp::Functor(*functor), *inner_id),
+ )]
+ }
+ _ => Vec::new(),
+ }
+ }
+ _ => Vec::new(),
+ }
+}
+
+/// Extracts a flat list of `LocalVarId`s from a pattern. Returns `None` if the
+/// pattern contains discards that cannot be mapped to individual variables.
+fn extract_flat_param_vars(pkg: &Package, pat_id: qsc_fir::fir::PatId) -> Option<Vec<LocalVarId>> {
+ let pat = pkg.get_pat(pat_id);
+ match &pat.kind {
+ PatKind::Bind(ident) => Some(vec![ident.id]),
+ PatKind::Tuple(sub_pats) => {
+ let mut variables = Vec::new();
+ for &sub_pat_id in sub_pats {
+ variables.extend(extract_flat_param_vars(pkg, sub_pat_id)?);
+ }
+ Some(variables)
+ }
+ PatKind::Discard => None,
+ }
+}
+
+/// Checks whether the args expression forwards exactly the given parameter
+/// variables in order. Handles both single-variable and tuple cases. 
+fn args_forward_params_in_order(
+ pkg: &Package,
+ args_id: ExprId,
+ actual_param_vars: &[LocalVarId],
+) -> bool {
+ extract_flat_arg_vars(pkg, args_id).is_some_and(|variables| variables == actual_param_vars)
+}
+
+/// Extracts a flat list of `LocalVarId`s from an arguments expression. Returns `None`
+/// if the expression is not a simple variable or tuple of variables (e.g. if it
+/// contains discards, literals, or complex expressions).
+fn extract_flat_arg_vars(pkg: &Package, args_id: ExprId) -> Option<Vec<LocalVarId>> {
+ let args_expr = pkg.get_expr(args_id);
+ match &args_expr.kind {
+ ExprKind::Var(Res::Local(var), _) => Some(vec![*var]),
+ ExprKind::Tuple(elements) => {
+ let mut variables = Vec::new();
+ for &element_id in elements {
+ variables.extend(extract_flat_arg_vars(pkg, element_id)?);
+ }
+ Some(variables)
+ }
+ _ => None,
+ }
+}
+
+/// Returns `true` if any of the capture parameter variables appear in the
+/// arguments expression.
+fn captures_appear_in_args(
+ pkg: &Package,
+ args_id: ExprId,
+ capture_param_vars: &[LocalVarId],
+) -> bool {
+ if capture_param_vars.is_empty() {
+ return false;
+ }
+ match extract_flat_arg_vars(pkg, args_id) {
+ Some(variables) => variables
+ .iter()
+ .any(|variable| capture_param_vars.contains(variable)),
+ _ => true, // Conservatively assume captures may be used in complex expressions.
+ }
+}
diff --git a/source/compiler/qsc_fir_transforms/src/defunctionalize/rewrite.rs b/source/compiler/qsc_fir_transforms/src/defunctionalize/rewrite.rs
new file mode 100644
index 0000000000..f054f1b381
--- /dev/null
+++ b/source/compiler/qsc_fir_transforms/src/defunctionalize/rewrite.rs
@@ -0,0 +1,3362 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+//! Rewrite phase of the defunctionalization pass.
+//!
+//! For each call site where a higher-order function is invoked with a concrete
+//! callable argument, this module rewrites the call to invoke the specialized
+//! 
callable directly, removes the callable argument from the call's argument +//! tuple, and threads closure captures as extra arguments when applicable. +//! +//! # Subsystems +//! +//! The module is organized into three cooperating subsystems: +//! +//! - **Dispatch synthesis** — synthesizes `if`/`else` chains that select a +//! specialized callee per reaching-definition branch for call sites whose +//! analysis produced a `Multi` lattice with branch conditions (see +//! [`synthesize_callsite_index_dispatch`], +//! [`synthesize_direct_index_dispatch`], and the +//! `synthesize_index_dispatch_plan` family). +//! - **Direct-call dispatch** — rewrites callee expressions, callee types, +//! and argument tuples so a HOF invocation becomes a direct call to the +//! specialized target (see [`rewrite_direct_call`], +//! [`rewrite_direct_callee`], [`rewrite_direct_closure_args`], and +//! `build_direct_global_callee_ty`). +//! - **Dead-local cleanup** — removes callable-typed locals whose only +//! remaining uses were direct-call rewrites, keeping `PostDefunc` clean +//! of arrow-typed residues (see the `prune_*` and +//! `remove_dead_callable_local_*` helpers). +//! +//! # Notes +//! +//! - A copy of the `apply_target_input_at_control_path` helper also lives +//! in [`super::specialize::apply_target_input_at_control_path`]. The copy +//! is retained so that specialize and rewrite can evolve their +//! controlled-layer handling independently without forcing a shared +//! abstraction boundary; update both copies in lockstep when +//! controlled-layer semantics change. 
+
+use super::types::{
+ AnalysisResult, CallSite, CallableParam, CapturedVar, ConcreteCallable, DirectCallSite,
+ SpecKey, peel_body_functors,
+};
+use super::{build_spec_key, ty_contains_arrow};
+use crate::EMPTY_EXEC_RANGE;
+use qsc_data_structures::functors::FunctorApp;
+use qsc_data_structures::span::Span;
+use qsc_fir::assigner::Assigner;
+use qsc_fir::fir::{
+ BinOp, Expr, ExprId, ExprKind, Field, Functor, ItemId, ItemKind, Lit, LocalItemId, LocalVarId,
+ Mutability, Package, PackageId, PackageLookup, PatId, PatKind, Res, StmtKind, UnOp,
+};
+use qsc_fir::ty::{Arrow, Prim, Ty};
+use rustc_hash::{FxHashMap, FxHashSet};
+
+/// Rewrites call sites in the target package so that higher-order calls are
+/// replaced with direct calls to their specialized counterparts.
+///
+/// For each call site with a matching specialization in `spec_map`:
+/// - The callee expression is replaced with a reference to the specialized
+/// callable.
+/// - The callable argument is removed from the argument tuple.
+/// - If the callable argument was a closure, its captured variables are
+/// appended as extra arguments.
+/// - The callee expression's type is updated to reflect the new signature.
+#[allow(clippy::too_many_lines)]
+pub(super) fn rewrite(
+ package: &mut Package,
+ package_id: PackageId,
+ analysis: &AnalysisResult,
+ spec_map: &FxHashMap<SpecKey, LocalItemId>,
+ assigner: &mut Assigner,
+) {
+ let expr_owner_lookup = build_expr_owner_lookup(package);
+ let mut rewritten_callable_arg_locals = FxHashSet::default();
+
+ // Build a lookup from HOF LocalItemId → CallableParam.
+ let param_lookup: FxHashMap<LocalItemId, &CallableParam> = {
+ let mut map = FxHashMap::default();
+ for p in &analysis.callable_params {
+ map.entry(p.callable_id).or_insert(p);
+ }
+ map
+ };
+
+ // Group resolved call sites by call_expr_id so that multi-callee sites
+ // (from branch-split analysis) are handled together. 
+ let mut grouped: FxHashMap<ExprId, Vec<(&CallSite, LocalItemId, &CallableParam)>> =
+ FxHashMap::default();
+
+ for call_site in &analysis.call_sites {
+ // Skip dynamic callables — they have no specialization.
+ if matches!(call_site.callable_arg, ConcreteCallable::Dynamic) {
+ continue;
+ }
+
+ let spec_key = build_spec_key(call_site);
+ let Some(&spec_local_id) = spec_map.get(&spec_key) else {
+ continue;
+ };
+
+ let hof_local_id = call_site.hof_item_id.item;
+ let Some(&param) = param_lookup.get(&hof_local_id) else {
+ continue;
+ };
+
+ grouped
+ .entry(call_site.call_expr_id)
+ .or_default()
+ .push((call_site, spec_local_id, param));
+ }
+
+ for (call_expr_id, entries) in &grouped {
+ if entries.len() == 1 {
+ let (call_site, spec_local_id, param) = entries[0];
+ collect_rewritten_callable_arg_local(
+ package,
+ &expr_owner_lookup,
+ call_site.call_expr_id,
+ call_site.arg_expr_id,
+ &mut rewritten_callable_arg_locals,
+ );
+ rewrite_one(
+ package,
+ package_id,
+ call_site,
+ param,
+ spec_local_id,
+ &expr_owner_lookup,
+ assigner,
+ );
+ } else {
+ for (call_site, _, _) in entries {
+ collect_rewritten_callable_arg_local(
+ package,
+ &expr_owner_lookup,
+ call_site.call_expr_id,
+ call_site.arg_expr_id,
+ &mut rewritten_callable_arg_locals,
+ );
+ }
+ branch_split_rewrite(
+ package,
+ package_id,
+ *call_expr_id,
+ entries,
+ &expr_owner_lookup,
+ assigner,
+ );
+ }
+ }
+
+ let mut grouped_direct: FxHashMap<ExprId, Vec<&DirectCallSite>> = FxHashMap::default();
+ for direct_call_site in &analysis.direct_call_sites {
+ grouped_direct
+ .entry(direct_call_site.call_expr_id)
+ .or_default()
+ .push(direct_call_site);
+ }
+
+ for entries in grouped_direct.values() {
+ if entries.len() == 1 && entries[0].condition.is_none() {
+ rewrite_direct_call(
+ package,
+ package_id,
+ entries[0],
+ &expr_owner_lookup,
+ &mut rewritten_callable_arg_locals,
+ assigner,
+ );
+ } else {
+ let call_expr_id = entries[0].call_expr_id;
+ let call_expr = package.get_expr(call_expr_id).clone();
+ let ExprKind::Call(callee_id, _) = call_expr.kind else {
+ 
continue;
+ };
+
+ collect_rewritten_callable_arg_local(
+ package,
+ &expr_owner_lookup,
+ call_expr_id,
+ callee_id,
+ &mut rewritten_callable_arg_locals,
+ );
+ branch_split_direct_call_rewrite(
+ package,
+ package_id,
+ call_expr_id,
+ entries,
+ &expr_owner_lookup,
+ assigner,
+ );
+ }
+ }
+
+ prune_dead_callable_arg_locals(package, &rewritten_callable_arg_locals);
+}
+
+/// Rewrites a `DirectCallSite` whose callee was resolved to a specific
+/// concrete callable into a direct invocation of that callable, pruning
+/// the now-unused callee expression.
+fn rewrite_direct_call(
+ package: &mut Package,
+ package_id: PackageId,
+ direct_call_site: &DirectCallSite,
+ expr_owner_lookup: &FxHashMap<ExprId, LocalItemId>,
+ rewritten_callable_arg_locals: &mut FxHashSet<(LocalItemId, LocalVarId)>,
+ assigner: &mut Assigner,
+) {
+ let call_expr = package.get_expr(direct_call_site.call_expr_id).clone();
+ let ExprKind::Call(callee_id, args_id) = call_expr.kind else {
+ return;
+ };
+ let (_, outer_functor) = peel_body_functors(package, callee_id);
+ let controlled_layers = usize::from(outer_functor.controlled);
+ let package_direct_lambda = match &direct_call_site.callable {
+ ConcreteCallable::Global { item_id, .. } if item_id.package == package_id => {
+ direct_lambda_packaged_input(package, item_id.item).is_some_and(|target_input| {
+ apply_target_input_at_control_path(
+ &package.get_expr(args_id).ty,
+ &target_input,
+ controlled_layers,
+ ) != package.get_expr(args_id).ty
+ })
+ }
+ _ => false,
+ };
+
+ collect_rewritten_callable_arg_local(
+ package,
+ expr_owner_lookup,
+ direct_call_site.call_expr_id,
+ callee_id,
+ rewritten_callable_arg_locals,
+ );
+
+ let captures = match &direct_call_site.callable {
+ ConcreteCallable::Closure { captures, .. 
} => {
+ resolve_rewrite_captures(package, callee_id, captures)
+ }
+ _ => Vec::new(),
+ };
+
+ rewrite_direct_callee(
+ package,
+ package_id,
+ callee_id,
+ &direct_call_site.callable,
+ &captures,
+ controlled_layers,
+ assigner,
+ );
+ if matches!(direct_call_site.callable, ConcreteCallable::Closure { .. })
+ || package_direct_lambda
+ {
+ rewrite_direct_closure_args(package, args_id, &captures, controlled_layers, assigner);
+ }
+}
+
+/// Rewrites a direct call whose callee has multiple possible concrete
+/// values by synthesizing a condition-indexed dispatch that selects the
+/// specialized callee matching the observed branch.
+fn branch_split_direct_call_rewrite(
+ package: &mut Package,
+ package_id: PackageId,
+ call_expr_id: ExprId,
+ entries: &[&DirectCallSite],
+ expr_owner_lookup: &FxHashMap<ExprId, LocalItemId>,
+ assigner: &mut Assigner,
+) {
+ let orig_call = package.get_expr(call_expr_id).clone();
+ let ExprKind::Call(orig_callee_id, orig_args_id) = orig_call.kind else {
+ return;
+ };
+ let span = orig_call.span;
+ let result_ty = orig_call.ty.clone();
+
+ let mut conditioned: Vec<(&DirectCallSite, ExprId)> = Vec::new();
+ let mut default = None;
+ for &entry in entries {
+ if let Some(condition) = entry.condition {
+ conditioned.push((entry, condition));
+ } else if default.is_none() {
+ default = Some(entry);
+ }
+ }
+
+ if conditioned.is_empty()
+ && entries.len() > 1
+ && let Some((synthetic_conditioned, default_idx)) = synthesize_direct_index_dispatch(
+ package,
+ expr_owner_lookup,
+ call_expr_id,
+ entries,
+ span,
+ assigner,
+ )
+ {
+ conditioned = synthetic_conditioned
+ .into_iter()
+ .map(|(entry_idx, condition)| (entries[entry_idx], condition))
+ .collect();
+ default = Some(entries[default_idx]);
+ }
+
+ let default_entry = if let Some(entry) = default {
+ entry
+ } else {
+ if conditioned.is_empty() {
+ return;
+ }
+ conditioned.pop().expect("non-empty conditioned").0
+ };
+
+ if conditioned.is_empty() {
+ let mut rewritten_callable_arg_locals = 
FxHashSet::default();
+ rewrite_direct_call(
+ package,
+ package_id,
+ default_entry,
+ expr_owner_lookup,
+ &mut rewritten_callable_arg_locals,
+ assigner,
+ );
+ return;
+ }
+
+ let orig_callee = package.get_expr(orig_callee_id).clone();
+ let orig_args = package.get_expr(orig_args_id).clone();
+
+ let else_call_id = create_direct_branch_call(
+ package,
+ package_id,
+ &orig_callee,
+ &orig_args,
+ span,
+ &result_ty,
+ default_entry,
+ assigner,
+ );
+
+ let mut current_else = else_call_id;
+ for (entry, cond_id) in conditioned.into_iter().rev() {
+ let branch_call_id = create_direct_branch_call(
+ package,
+ package_id,
+ &orig_callee,
+ &orig_args,
+ span,
+ &result_ty,
+ entry,
+ assigner,
+ );
+ current_else = alloc_if_expr(
+ package,
+ span,
+ &result_ty,
+ cond_id,
+ branch_call_id,
+ current_else,
+ assigner,
+ );
+ }
+
+ let dispatch = package
+ .exprs
+ .get(current_else)
+ .expect("dispatch expr should exist")
+ .clone();
+ let orig = package
+ .exprs
+ .get_mut(call_expr_id)
+ .expect("call expr should exist");
+ orig.kind = dispatch.kind;
+ orig.ty = dispatch.ty;
+}
+
+/// Records a local variable whose call-site rewrite now references a
+/// specialized callable, marking it eligible for the dead-local cleanup
+/// subsystem.
+fn collect_rewritten_callable_arg_local(
+ package: &Package,
+ expr_owner_lookup: &FxHashMap<ExprId, LocalItemId>,
+ call_expr_id: ExprId,
+ expr_id: ExprId,
+ rewritten_callable_arg_locals: &mut FxHashSet<(LocalItemId, LocalVarId)>,
+) {
+ let expr = package.get_expr(expr_id);
+ if let ExprKind::Var(Res::Local(var), _) = expr.kind
+ && let Some(&callable_id) = expr_owner_lookup.get(&call_expr_id)
+ {
+ rewritten_callable_arg_locals.insert((callable_id, var));
+ }
+}
+
+/// Synthesizes an index-dispatch `if`/`else` chain for a HOF call site that
+/// resolves to multiple callables via branch-split analysis. 
+fn synthesize_callsite_index_dispatch(
+ package: &mut Package,
+ expr_owner_lookup: &FxHashMap<ExprId, LocalItemId>,
+ call_expr_id: ExprId,
+ entries: &[(&CallSite, LocalItemId, &CallableParam)],
+ span: Span,
+ assigner: &mut Assigner,
+) -> Option<(Vec<(usize, ExprId)>, usize)> {
+ let callables = entries
+ .iter()
+ .map(|(call_site, _, _)| call_site.callable_arg.clone())
+ .collect::<Vec<_>>();
+ synthesize_index_dispatch_plan(
+ package,
+ expr_owner_lookup,
+ call_expr_id,
+ entries.first()?.0.arg_expr_id,
+ &callables,
+ span,
+ assigner,
+ )
+}
+
+/// Synthesizes an index-dispatch `if`/`else` chain for a direct-call site
+/// whose callee expression resolves to multiple concrete callables.
+fn synthesize_direct_index_dispatch(
+ package: &mut Package,
+ expr_owner_lookup: &FxHashMap<ExprId, LocalItemId>,
+ call_expr_id: ExprId,
+ entries: &[&DirectCallSite],
+ span: Span,
+ assigner: &mut Assigner,
+) -> Option<(Vec<(usize, ExprId)>, usize)> {
+ let ExprKind::Call(callee_id, _) = package.get_expr(call_expr_id).kind else {
+ return None;
+ };
+ let callables = entries
+ .iter()
+ .map(|entry| entry.callable.clone())
+ .collect::<Vec<_>>();
+ synthesize_index_dispatch_plan(
+ package,
+ expr_owner_lookup,
+ call_expr_id,
+ callee_id,
+ &callables,
+ span,
+ assigner,
+ )
+}
+
+/// Plans the branches of an index-dispatch rewrite by pairing each
+/// candidate callable with the condition expression that selects it. 
+fn synthesize_index_dispatch_plan(
+ package: &mut Package,
+ expr_owner_lookup: &FxHashMap<ExprId, LocalItemId>,
+ owner_expr_id: ExprId,
+ dispatch_expr_id: ExprId,
+ callables: &[ConcreteCallable],
+ span: Span,
+ assigner: &mut Assigner,
+) -> Option<(Vec<(usize, ExprId)>, usize)> {
+ if callables.len() < 2 {
+ return None;
+ }
+
+ let (index_expr_id, indexed_callables) =
+ resolve_index_dispatch_source(package, expr_owner_lookup, owner_expr_id, dispatch_expr_id)?;
+
+ let mut entry_positions = Vec::with_capacity(callables.len());
+ for callable in callables {
+ let position = indexed_callables
+ .iter()
+ .position(|candidate| candidate == callable)?;
+ entry_positions.push(position);
+ }
+
+ let (default_idx, _) = entry_positions
+ .iter()
+ .copied()
+ .enumerate()
+ .max_by_key(|(_, position)| *position)?;
+
+ let mut conditioned = Vec::with_capacity(callables.len().saturating_sub(1));
+ for (entry_idx, position) in entry_positions.into_iter().enumerate() {
+ if entry_idx == default_idx {
+ continue;
+ }
+ let condition = alloc_index_eq_expr(package, index_expr_id, position, span, assigner);
+ conditioned.push((entry_idx, condition));
+ }
+
+ Some((conditioned, default_idx))
+}
+
+/// Locates the source of a dynamic dispatch (for example the index
+/// expression selecting an element in a callable array) that
+/// `synthesize_*_index_dispatch` will compare against per-branch values.
+fn resolve_index_dispatch_source(
+ package: &Package,
+ expr_owner_lookup: &FxHashMap<ExprId, LocalItemId>,
+ owner_expr_id: ExprId,
+ dispatch_expr_id: ExprId,
+) -> Option<(ExprId, Vec<ConcreteCallable>)> {
+ let source_expr_id =
+ resolve_dispatch_source_expr(package, expr_owner_lookup, owner_expr_id, dispatch_expr_id)?;
+ let ExprKind::Index(array_expr_id, index_expr_id) = package.get_expr(source_expr_id).kind
+ else {
+ return None;
+ };
+
+ // Try direct resolution: array elements are callables. 
+ if let Some(indexed_callables) = + resolve_array_expr_to_callables(package, expr_owner_lookup, owner_expr_id, array_expr_id) + && indexed_callables.len() >= 2 + { + return Some((index_expr_id, indexed_callables)); + } + + // Direct resolution failed: array elements may be tuples. + // Check if the dispatch expression was a local variable bound from a + // tuple pattern, and try extracting the appropriate field from each + // array element before resolving. + let field_path = + resolve_dispatch_field_path(package, expr_owner_lookup, owner_expr_id, dispatch_expr_id)?; + let indexed_callables = resolve_array_expr_to_callables_with_field( + package, + expr_owner_lookup, + owner_expr_id, + array_expr_id, + &field_path, + )?; + if indexed_callables.len() < 2 { + return None; + } + Some((index_expr_id, indexed_callables)) +} + +/// For a dispatch expression that is a local variable bound from a tuple +/// pattern, returns the field position path within the tuple. +fn resolve_dispatch_field_path( + package: &Package, + expr_owner_lookup: &FxHashMap, + owner_expr_id: ExprId, + dispatch_expr_id: ExprId, +) -> Option> { + let expr = package.get_expr(dispatch_expr_id); + if let ExprKind::Var(Res::Local(local_var), _) = expr.kind { + let owner_callable = *expr_owner_lookup.get(&owner_expr_id)?; + find_var_tuple_field_path_in_callable(package, owner_callable, local_var) + } else { + None + } +} + +/// Resolves the expression feeding an index dispatch back to its defining +/// source (literal, local, or field access) so per-branch conditions can +/// compare directly against it. 
+fn resolve_dispatch_source_expr( + package: &Package, + expr_owner_lookup: &FxHashMap, + owner_expr_id: ExprId, + expr_id: ExprId, +) -> Option { + let expr = package.get_expr(expr_id); + match expr.kind { + ExprKind::Var(Res::Local(local_var), _) => { + let owner_callable = *expr_owner_lookup.get(&owner_expr_id)?; + let init_expr_id = + find_local_init_expr_in_callable(package, owner_callable, local_var)?; + if init_expr_id == expr_id { + None + } else { + resolve_dispatch_source_expr( + package, + expr_owner_lookup, + owner_expr_id, + init_expr_id, + ) + } + } + ExprKind::Block(block_id) => { + let block = package.get_block(block_id); + let stmt_id = *block.stmts.last()?; + let stmt = package.get_stmt(stmt_id); + #[allow(clippy::manual_let_else)] + let tail_expr_id = match stmt.kind { + StmtKind::Expr(expr_id) | StmtKind::Semi(expr_id) => expr_id, + _ => return None, + }; + resolve_dispatch_source_expr(package, expr_owner_lookup, owner_expr_id, tail_expr_id) + } + ExprKind::Return(inner_expr_id) => { + resolve_dispatch_source_expr(package, expr_owner_lookup, owner_expr_id, inner_expr_id) + } + _ => Some(expr_id), + } +} + +/// Resolves an array-literal expression to the ordered list of concrete +/// callables it contains, used by index-dispatch synthesis. 
+fn resolve_array_expr_to_callables( + package: &Package, + expr_owner_lookup: &FxHashMap, + owner_expr_id: ExprId, + expr_id: ExprId, +) -> Option> { + let source_expr_id = + resolve_dispatch_source_expr(package, expr_owner_lookup, owner_expr_id, expr_id)?; + let expr = package.get_expr(source_expr_id); + let elements = match &expr.kind { + ExprKind::Array(elements) | ExprKind::ArrayLit(elements) | ExprKind::Tuple(elements) => { + elements.clone() + } + _ => return None, + }; + + let mut callables = Vec::with_capacity(elements.len()); + for elem_expr_id in elements { + let callable = resolve_expr_to_concrete_callable( + package, + expr_owner_lookup, + owner_expr_id, + elem_expr_id, + )?; + if !callables.contains(&callable) { + callables.push(callable); + } + } + + Some(callables) +} + +/// Extracts a nested tuple field from an expression by following a field path. +/// For `field_path = [1]`, returns the second element of a tuple expression. +fn extract_tuple_field(package: &Package, expr_id: ExprId, path: &[usize]) -> Option { + let mut current = expr_id; + for &idx in path { + let expr = package.get_expr(current); + if let ExprKind::Tuple(fields) = &expr.kind { + current = *fields.get(idx)?; + } else { + return None; + } + } + Some(current) +} + +/// Like `resolve_array_expr_to_callables`, but first extracts the tuple field +/// at `field_path` from each array element before resolving to a callable. 
+fn resolve_array_expr_to_callables_with_field( + package: &Package, + expr_owner_lookup: &FxHashMap, + owner_expr_id: ExprId, + array_expr_id: ExprId, + field_path: &[usize], +) -> Option> { + let source_expr_id = + resolve_dispatch_source_expr(package, expr_owner_lookup, owner_expr_id, array_expr_id)?; + let expr = package.get_expr(source_expr_id); + let elements = match &expr.kind { + ExprKind::Array(elements) | ExprKind::ArrayLit(elements) | ExprKind::Tuple(elements) => { + elements.clone() + } + _ => return None, + }; + + let mut callables = Vec::with_capacity(elements.len()); + for elem_expr_id in elements { + let field_expr_id = extract_tuple_field(package, elem_expr_id, field_path)?; + let callable = resolve_expr_to_concrete_callable( + package, + expr_owner_lookup, + owner_expr_id, + field_expr_id, + )?; + if !callables.contains(&callable) { + callables.push(callable); + } + } + + Some(callables) +} + +/// Attempts to resolve an expression to a single concrete callable (global +/// or closure), mirroring the analysis-phase resolution but on the +/// rewritten package. +fn resolve_expr_to_concrete_callable( + package: &Package, + expr_owner_lookup: &FxHashMap, + owner_expr_id: ExprId, + expr_id: ExprId, +) -> Option { + let source_expr_id = + resolve_dispatch_source_expr(package, expr_owner_lookup, owner_expr_id, expr_id)?; + let (base_id, functor) = peel_body_functors(package, source_expr_id); + let expr = package.get_expr(base_id); + match expr.kind { + ExprKind::Var(Res::Item(item_id), _) => Some(ConcreteCallable::Global { item_id, functor }), + _ => None, + } +} + +/// Allocates a `BinOp(Eq, index_expr, Int(index_value))` expression used as +/// the condition guard for index-dispatch branches. +/// +/// # Before +/// ```text +/// (no expression) +/// ``` +/// # After +/// ```text +/// Expr { BinOp(Eq, index_expr, Lit(Int(index_value))) : Bool } +/// ``` +/// +/// # Mutations +/// - Inserts two new `Expr` nodes (literal + comparison) through `assigner`. 
+fn alloc_index_eq_expr( + package: &mut Package, + index_expr_id: ExprId, + index_value: usize, + span: Span, + assigner: &mut Assigner, +) -> ExprId { + let lit_id = assigner.next_expr(); + let index_value = i64::try_from(index_value).expect("dispatch index should fit in i64"); + package.exprs.insert( + lit_id, + Expr { + id: lit_id, + span, + ty: Ty::Prim(Prim::Int), + kind: ExprKind::Lit(Lit::Int(index_value)), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + + let cond_id = assigner.next_expr(); + package.exprs.insert( + cond_id, + Expr { + id: cond_id, + span, + ty: Ty::Prim(Prim::Bool), + kind: ExprKind::BinOp(BinOp::Eq, index_expr_id, lit_id), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + cond_id +} + +/// Locates the initializer expression for a given local variable inside a +/// reachable callable body. +fn find_local_init_expr_in_callable( + package: &Package, + callable_id: LocalItemId, + local_var: LocalVarId, +) -> Option { + let Some(ItemKind::Callable(decl)) = package.items.get(callable_id).map(|item| &item.kind) + else { + return None; + }; + + find_local_init_expr_in_callable_impl(package, &decl.implementation, local_var) +} + +/// Recurses over a `CallableImpl` variant to locate a local variable's +/// initializer expression. 
+fn find_local_init_expr_in_callable_impl( + package: &Package, + callable_impl: &qsc_fir::fir::CallableImpl, + local_var: LocalVarId, +) -> Option { + match callable_impl { + qsc_fir::fir::CallableImpl::Intrinsic => None, + qsc_fir::fir::CallableImpl::SimulatableIntrinsic(spec_decl) => { + find_local_init_expr_in_block(package, spec_decl.block, local_var) + } + qsc_fir::fir::CallableImpl::Spec(spec_impl) => { + find_local_init_expr_in_block(package, spec_impl.body.block, local_var).or_else(|| { + [ + spec_impl.adj.as_ref(), + spec_impl.ctl.as_ref(), + spec_impl.ctl_adj.as_ref(), + ] + .into_iter() + .flatten() + .find_map(|spec| find_local_init_expr_in_block(package, spec.block, local_var)) + }) + } + } +} + +/// Walks a block's statements looking for the `Local` binding of the +/// requested local variable. +fn find_local_init_expr_in_block( + package: &Package, + block_id: qsc_fir::fir::BlockId, + local_var: LocalVarId, +) -> Option { + let block = package.get_block(block_id); + for &stmt_id in &block.stmts { + let stmt = package.get_stmt(stmt_id); + if let StmtKind::Local(_, pat_id, init_expr_id) = stmt.kind + && pat_binds_local_var(package, pat_id, local_var) + { + return Some(init_expr_id); + } + + let nested = match stmt.kind { + StmtKind::Expr(expr_id) | StmtKind::Semi(expr_id) | StmtKind::Local(_, _, expr_id) => { + find_local_init_expr_in_expr(package, expr_id, local_var) + } + StmtKind::Item(_) => None, + }; + if nested.is_some() { + return nested; + } + } + + None +} + +/// Descends into nested expressions (blocks, conditionals, loops) while +/// searching for a local variable's initializer. 
+fn find_local_init_expr_in_expr( + package: &Package, + expr_id: ExprId, + local_var: LocalVarId, +) -> Option { + let expr = package.get_expr(expr_id); + match &expr.kind { + ExprKind::Array(exprs) | ExprKind::ArrayLit(exprs) | ExprKind::Tuple(exprs) => exprs + .iter() + .find_map(|&expr_id| find_local_init_expr_in_expr(package, expr_id, local_var)), + ExprKind::ArrayRepeat(lhs, rhs) + | ExprKind::Assign(lhs, rhs) + | ExprKind::AssignOp(_, lhs, rhs) + | ExprKind::BinOp(_, lhs, rhs) + | ExprKind::Call(lhs, rhs) + | ExprKind::Index(lhs, rhs) + | ExprKind::AssignField(lhs, _, rhs) + | ExprKind::UpdateField(lhs, _, rhs) => { + find_local_init_expr_in_expr(package, *lhs, local_var) + .or_else(|| find_local_init_expr_in_expr(package, *rhs, local_var)) + } + ExprKind::AssignIndex(a, b, c) | ExprKind::UpdateIndex(a, b, c) => { + find_local_init_expr_in_expr(package, *a, local_var) + .or_else(|| find_local_init_expr_in_expr(package, *b, local_var)) + .or_else(|| find_local_init_expr_in_expr(package, *c, local_var)) + } + ExprKind::Block(block_id) => find_local_init_expr_in_block(package, *block_id, local_var), + ExprKind::Fail(inner) + | ExprKind::Field(inner, _) + | ExprKind::Return(inner) + | ExprKind::UnOp(_, inner) => find_local_init_expr_in_expr(package, *inner, local_var), + ExprKind::If(cond, body, otherwise) => { + find_local_init_expr_in_expr(package, *cond, local_var) + .or_else(|| find_local_init_expr_in_expr(package, *body, local_var)) + .or_else(|| { + otherwise.and_then(|expr_id| { + find_local_init_expr_in_expr(package, expr_id, local_var) + }) + }) + } + ExprKind::Range(start, step, end) => start + .and_then(|expr_id| find_local_init_expr_in_expr(package, expr_id, local_var)) + .or_else(|| { + step.and_then(|expr_id| find_local_init_expr_in_expr(package, expr_id, local_var)) + }) + .or_else(|| { + end.and_then(|expr_id| find_local_init_expr_in_expr(package, expr_id, local_var)) + }), + ExprKind::String(components) => components.iter().find_map(|component| 
match component { + qsc_fir::fir::StringComponent::Expr(expr_id) => { + find_local_init_expr_in_expr(package, *expr_id, local_var) + } + qsc_fir::fir::StringComponent::Lit(_) => None, + }), + ExprKind::Struct(_, copy, fields) => copy + .and_then(|expr_id| find_local_init_expr_in_expr(package, expr_id, local_var)) + .or_else(|| { + fields + .iter() + .find_map(|field| find_local_init_expr_in_expr(package, field.value, local_var)) + }), + ExprKind::While(cond, block_id) => find_local_init_expr_in_expr(package, *cond, local_var) + .or_else(|| find_local_init_expr_in_block(package, *block_id, local_var)), + ExprKind::Closure(_, _) | ExprKind::Hole | ExprKind::Lit(_) | ExprKind::Var(_, _) => None, + } +} + +/// Removes callable-typed argument locals whose only remaining uses were +/// rewritten into direct dispatch calls, leaving no arrow-typed residue. +/// +/// # Before +/// ```text +/// { let f = some_callable; specialized_call(args); } +/// ``` +/// # After +/// ```text +/// { specialized_call(args); } // dead `let f` removed +/// ``` +/// +/// # Mutations +/// - Removes `Local` binding statements and `Var` references for dead locals +/// via [`remove_dead_callable_local_from_callable`] and +/// [`prune_dead_top_level_callable_locals`]. 
+fn prune_dead_callable_arg_locals( + package: &mut Package, + rewritten_callable_arg_locals: &FxHashSet<(LocalItemId, LocalVarId)>, +) { + for &(callable_id, local_var) in rewritten_callable_arg_locals { + if !local_var_is_used_in_callable(package, callable_id, local_var) { + remove_dead_callable_local_from_callable(package, callable_id, local_var); + } + } + + prune_dead_top_level_callable_locals(package); +} + +fn build_expr_owner_lookup(package: &Package) -> FxHashMap { + let mut expr_owner_lookup = FxHashMap::default(); + + for (item_id, item) in &package.items { + if let ItemKind::Callable(decl) = &item.kind { + crate::walk_utils::for_each_expr_in_callable_impl( + package, + &decl.implementation, + &mut |expr_id, _expr| { + expr_owner_lookup.insert(expr_id, item_id); + }, + ); + } + } + + expr_owner_lookup +} + +fn local_var_is_used_in_callable( + package: &Package, + callable_id: LocalItemId, + local_var: LocalVarId, +) -> bool { + let Some(ItemKind::Callable(decl)) = package.items.get(callable_id).map(|item| &item.kind) + else { + return false; + }; + + let mut used = false; + crate::walk_utils::for_each_expr_in_callable_impl( + package, + &decl.implementation, + &mut |_expr_id, expr| { + if matches!(expr.kind, ExprKind::Var(Res::Local(var), _) if var == local_var) { + used = true; + } + }, + ); + used +} + +/// Removes a specific dead callable local from the given callable's body +/// by deleting its `Local` binding and any references that remain. +/// +/// # Before +/// ```text +/// body { let f : Arrow = init; ... /* no uses of f */ ... } +/// ``` +/// # After +/// ```text +/// body { ... } +/// ``` +/// +/// # Mutations +/// - Filters `Block.stmts` to remove the `Local` binding for `local_var`. +/// - Recurses into nested blocks via [`remove_dead_callable_local_from_block`]. 
+fn remove_dead_callable_local_from_callable( + package: &mut Package, + callable_id: LocalItemId, + local_var: LocalVarId, +) { + let Some(ItemKind::Callable(decl)) = package.items.get(callable_id).map(|item| &item.kind) + else { + return; + }; + + let implementation = decl.implementation.clone(); + match implementation { + qsc_fir::fir::CallableImpl::Intrinsic => {} + qsc_fir::fir::CallableImpl::SimulatableIntrinsic(spec_decl) => { + remove_dead_callable_local_from_block(package, spec_decl.block, local_var); + } + qsc_fir::fir::CallableImpl::Spec(spec_impl) => { + remove_dead_callable_local_from_block(package, spec_impl.body.block, local_var); + for spec in [spec_impl.adj, spec_impl.ctl, spec_impl.ctl_adj] + .into_iter() + .flatten() + { + remove_dead_callable_local_from_block(package, spec.block, local_var); + } + } + } +} + +/// Removes top-level callable-typed locals whose only uses were direct +/// dispatch rewrites, scoped to the package-level entry expression. +/// +/// # Before +/// ```text +/// body { let g : Arrow = ...; /* no remaining uses of g */ ... } +/// ``` +/// # After +/// ```text +/// body { ... } // dead binding removed +/// ``` +/// +/// # Mutations +/// - Filters `Block.stmts` across all callable bodies in the package. 
+fn prune_dead_top_level_callable_locals(package: &mut Package) { + let callable_items: Vec<(LocalItemId, qsc_fir::fir::CallableImpl)> = package + .items + .iter() + .filter_map(|(item_id, item)| match &item.kind { + ItemKind::Callable(decl) => Some((item_id, decl.implementation.clone())), + _ => None, + }) + .collect(); + + for (_item_id, implementation) in callable_items { + match implementation { + qsc_fir::fir::CallableImpl::Intrinsic => {} + qsc_fir::fir::CallableImpl::SimulatableIntrinsic(spec_decl) => { + prune_dead_callable_locals_in_block(package, spec_decl.block); + } + qsc_fir::fir::CallableImpl::Spec(spec_impl) => { + prune_dead_callable_locals_in_block(package, spec_impl.body.block); + for spec in [spec_impl.adj, spec_impl.ctl, spec_impl.ctl_adj] + .into_iter() + .flatten() + { + prune_dead_callable_locals_in_block(package, spec.block); + } + } + } + } +} + +/// Walks a block looking for dead callable-typed locals introduced by +/// direct-call rewrites and removes them in place. +/// +/// Iterates until no more removals occur so that cascading dead-local chains +/// (e.g. `let a = closure; let b = a;`) are fully pruned in a single call +/// rather than requiring multiple outer fixpoint iterations. +/// +/// # Before +/// ```text +/// { let a : Arrow = closure; let b : Arrow = a; specialized_call(args); } +/// ``` +/// # After +/// ```text +/// { specialized_call(args); } // both dead bindings removed +/// ``` +/// +/// # Mutations +/// - Rewrites `Block.stmts` to drop unused `Local` bindings in a fixpoint +/// loop, then recurses into nested blocks. 
+fn prune_dead_callable_locals_in_block(package: &mut Package, block_id: qsc_fir::fir::BlockId) { + loop { + let stmt_ids = package.get_block(block_id).stmts.clone(); + let initial_count = stmt_ids.len(); + let mut retained = Vec::with_capacity(initial_count); + + for stmt_id in stmt_ids { + let stmt = package.get_stmt(stmt_id); + let remove_stmt = match stmt.kind { + StmtKind::Local(Mutability::Immutable, pat_id, _) => { + let pat = package.get_pat(pat_id); + if local_ty_contains_arrow_through_udts(package, &pat.ty) { + let mut bound_vars = Vec::new(); + collect_bound_pat_vars(package, pat_id, &mut bound_vars); + !bound_vars.is_empty() + && bound_vars.iter().all(|var| { + let mut uses = Vec::new(); + crate::walk_utils::collect_uses_in_block( + package, block_id, *var, &mut uses, + ); + uses.is_empty() + }) + } else { + false + } + } + _ => false, + }; + + if !remove_stmt { + retained.push(stmt_id); + } + } + + package + .blocks + .get_mut(block_id) + .expect("block should exist") + .stmts + .clone_from(&retained); + + if retained.len() == initial_count { + // No removals this pass — walk nested blocks and stop. + for stmt_id in retained { + prune_dead_callable_locals_in_stmt(package, stmt_id); + } + break; + } + } +} + +/// Removes a dead callable local scoped to a specific block, including its +/// `Local` binding and any remaining references. +/// +/// # Before +/// ```text +/// { let f : Arrow = init; stmt1; stmt2; } +/// ``` +/// # After +/// ```text +/// { stmt1; stmt2; } // binding removed when f is unused +/// ``` +/// +/// # Mutations +/// - Filters `Block.stmts` to remove the dead binding. +/// - Recurses into nested blocks via [`remove_dead_callable_local_from_stmt`]. 
+fn remove_dead_callable_local_from_block( + package: &mut Package, + block_id: qsc_fir::fir::BlockId, + local_var: LocalVarId, +) { + let stmt_ids = package.get_block(block_id).stmts.clone(); + let mut retained = Vec::with_capacity(stmt_ids.len()); + + for stmt_id in stmt_ids { + let stmt = package.get_stmt(stmt_id); + let remove_stmt = if let StmtKind::Local(Mutability::Immutable, pat_id, _) = stmt.kind + && local_ty_contains_arrow_through_udts(package, &package.get_pat(pat_id).ty) + && pat_binds_local_var(package, pat_id, local_var) + { + // Only remove when ALL bound variables in the pattern are + // unused; a tuple pattern may bind siblings that are still live. + let mut bound_vars = Vec::new(); + collect_bound_pat_vars(package, pat_id, &mut bound_vars); + bound_vars.iter().all(|&var| { + let mut uses = Vec::new(); + crate::walk_utils::collect_uses_in_block(package, block_id, var, &mut uses); + uses.is_empty() + }) + } else { + false + }; + + if !remove_stmt { + retained.push(stmt_id); + } + } + + let retained_for_walk = retained.clone(); + package + .blocks + .get_mut(block_id) + .expect("block should exist") + .stmts = retained; + + for stmt_id in retained_for_walk { + remove_dead_callable_local_from_stmt(package, stmt_id, local_var); + } +} + +/// Inspects a single statement for dead callable-local bindings and +/// deletes them when safe. +/// +/// # Mutations +/// - Delegates to [`prune_dead_callable_locals_in_expr`] for the +/// statement's inner expression. +fn prune_dead_callable_locals_in_stmt(package: &mut Package, stmt_id: qsc_fir::fir::StmtId) { + let stmt = package.get_stmt(stmt_id).clone(); + match stmt.kind { + StmtKind::Expr(expr_id) | StmtKind::Semi(expr_id) | StmtKind::Local(_, _, expr_id) => { + prune_dead_callable_locals_in_expr(package, expr_id); + } + StmtKind::Item(_) => {} + } +} + +/// Descends into an expression subtree looking for dead callable-local +/// bindings introduced by direct-call rewrites. 
+/// +/// # Mutations +/// - Delegates to [`prune_dead_callable_locals_in_block`] for nested +/// `Block` and `While` bodies, recursing until all dead bindings are +/// removed. +fn prune_dead_callable_locals_in_expr(package: &mut Package, expr_id: ExprId) { + let expr = package.get_expr(expr_id).clone(); + match expr.kind { + ExprKind::Array(exprs) | ExprKind::ArrayLit(exprs) | ExprKind::Tuple(exprs) => { + for expr_id in exprs { + prune_dead_callable_locals_in_expr(package, expr_id); + } + } + ExprKind::ArrayRepeat(lhs, rhs) + | ExprKind::Assign(lhs, rhs) + | ExprKind::AssignOp(_, lhs, rhs) + | ExprKind::BinOp(_, lhs, rhs) + | ExprKind::Call(lhs, rhs) + | ExprKind::Index(lhs, rhs) + | ExprKind::AssignField(lhs, _, rhs) + | ExprKind::UpdateField(lhs, _, rhs) => { + prune_dead_callable_locals_in_expr(package, lhs); + prune_dead_callable_locals_in_expr(package, rhs); + } + ExprKind::AssignIndex(a, b, c) | ExprKind::UpdateIndex(a, b, c) => { + prune_dead_callable_locals_in_expr(package, a); + prune_dead_callable_locals_in_expr(package, b); + prune_dead_callable_locals_in_expr(package, c); + } + ExprKind::Block(block_id) => prune_dead_callable_locals_in_block(package, block_id), + ExprKind::Fail(inner) + | ExprKind::Field(inner, _) + | ExprKind::Return(inner) + | ExprKind::UnOp(_, inner) => prune_dead_callable_locals_in_expr(package, inner), + ExprKind::If(cond, body, otherwise) => { + prune_dead_callable_locals_in_expr(package, cond); + prune_dead_callable_locals_in_expr(package, body); + if let Some(otherwise) = otherwise { + prune_dead_callable_locals_in_expr(package, otherwise); + } + } + ExprKind::Range(start, step, end) => { + for expr_id in [start, step, end].into_iter().flatten() { + prune_dead_callable_locals_in_expr(package, expr_id); + } + } + ExprKind::String(components) => { + for component in components { + if let qsc_fir::fir::StringComponent::Expr(expr_id) = component { + prune_dead_callable_locals_in_expr(package, expr_id); + } + } + } + 
ExprKind::Struct(_, copy, fields) => { + if let Some(copy) = copy { + prune_dead_callable_locals_in_expr(package, copy); + } + for field in fields { + prune_dead_callable_locals_in_expr(package, field.value); + } + } + ExprKind::While(cond, block_id) => { + prune_dead_callable_locals_in_expr(package, cond); + prune_dead_callable_locals_in_block(package, block_id); + } + ExprKind::Closure(_, _) | ExprKind::Hole | ExprKind::Lit(_) | ExprKind::Var(_, _) => {} + } +} + +/// Removes a specific dead callable local scoped to a single statement. +/// +/// # Mutations +/// - Delegates to [`remove_dead_callable_local_from_expr`] for the +/// statement's inner expression. +fn remove_dead_callable_local_from_stmt( + package: &mut Package, + stmt_id: qsc_fir::fir::StmtId, + local_var: LocalVarId, +) { + let stmt = package.get_stmt(stmt_id).clone(); + match stmt.kind { + StmtKind::Expr(expr_id) | StmtKind::Semi(expr_id) | StmtKind::Local(_, _, expr_id) => { + remove_dead_callable_local_from_expr(package, expr_id, local_var); + } + StmtKind::Item(_) => {} + } +} + +/// Removes references to a dead callable local inside a given expression +/// subtree. +/// +/// # Mutations +/// - Recurses through `Block`, `If`, `While`, and compound expressions +/// to reach every nested block via +/// [`remove_dead_callable_local_from_block`]. 
+fn remove_dead_callable_local_from_expr( + package: &mut Package, + expr_id: ExprId, + local_var: LocalVarId, +) { + let expr = package.get_expr(expr_id).clone(); + match expr.kind { + ExprKind::Array(exprs) | ExprKind::ArrayLit(exprs) | ExprKind::Tuple(exprs) => { + for expr_id in exprs { + remove_dead_callable_local_from_expr(package, expr_id, local_var); + } + } + ExprKind::ArrayRepeat(lhs, rhs) + | ExprKind::Assign(lhs, rhs) + | ExprKind::AssignOp(_, lhs, rhs) + | ExprKind::BinOp(_, lhs, rhs) + | ExprKind::Call(lhs, rhs) + | ExprKind::Index(lhs, rhs) + | ExprKind::AssignField(lhs, _, rhs) + | ExprKind::UpdateField(lhs, _, rhs) => { + remove_dead_callable_local_from_expr(package, lhs, local_var); + remove_dead_callable_local_from_expr(package, rhs, local_var); + } + ExprKind::AssignIndex(a, b, c) | ExprKind::UpdateIndex(a, b, c) => { + remove_dead_callable_local_from_expr(package, a, local_var); + remove_dead_callable_local_from_expr(package, b, local_var); + remove_dead_callable_local_from_expr(package, c, local_var); + } + ExprKind::Block(block_id) => { + remove_dead_callable_local_from_block(package, block_id, local_var); + } + ExprKind::Fail(inner) + | ExprKind::Field(inner, _) + | ExprKind::Return(inner) + | ExprKind::UnOp(_, inner) => { + remove_dead_callable_local_from_expr(package, inner, local_var); + } + ExprKind::If(cond, body, otherwise) => { + remove_dead_callable_local_from_expr(package, cond, local_var); + remove_dead_callable_local_from_expr(package, body, local_var); + if let Some(otherwise) = otherwise { + remove_dead_callable_local_from_expr(package, otherwise, local_var); + } + } + ExprKind::Range(start, step, end) => { + for expr_id in [start, step, end].into_iter().flatten() { + remove_dead_callable_local_from_expr(package, expr_id, local_var); + } + } + ExprKind::String(components) => { + for component in components { + if let qsc_fir::fir::StringComponent::Expr(expr_id) = component { + remove_dead_callable_local_from_expr(package, 
expr_id, local_var); + } + } + } + ExprKind::Struct(_, copy, fields) => { + if let Some(copy) = copy { + remove_dead_callable_local_from_expr(package, copy, local_var); + } + for field in fields { + remove_dead_callable_local_from_expr(package, field.value, local_var); + } + } + ExprKind::While(cond, block_id) => { + remove_dead_callable_local_from_expr(package, cond, local_var); + remove_dead_callable_local_from_block(package, block_id, local_var); + } + ExprKind::Closure(_, _) | ExprKind::Hole | ExprKind::Lit(_) | ExprKind::Var(_, _) => {} + } +} + +fn collect_bound_pat_vars(package: &Package, pat_id: PatId, bound_vars: &mut Vec) { + let pat = package.get_pat(pat_id); + match &pat.kind { + PatKind::Bind(ident) => bound_vars.push(ident.id), + PatKind::Discard => {} + PatKind::Tuple(pats) => { + for &sub_pat_id in pats { + collect_bound_pat_vars(package, sub_pat_id, bound_vars); + } + } + } +} + +fn pat_binds_local_var(package: &Package, pat_id: PatId, local_var: LocalVarId) -> bool { + let mut bound_vars = Vec::new(); + collect_bound_pat_vars(package, pat_id, &mut bound_vars); + bound_vars + .into_iter() + .any(|bound_var| bound_var == local_var) +} + +/// For a local variable bound inside a tuple pattern (e.g., +/// `let (_, callee, _) = tuple_expr`), returns the field position +/// path (e.g., `[1]` for position 1). 
+fn find_var_tuple_field_path_in_callable( + package: &Package, + callable_id: LocalItemId, + local_var: LocalVarId, +) -> Option> { + let item = package.items.get(callable_id)?; + let ItemKind::Callable(decl) = &item.kind else { + return None; + }; + match &decl.implementation { + qsc_fir::fir::CallableImpl::Intrinsic => None, + qsc_fir::fir::CallableImpl::SimulatableIntrinsic(spec_decl) => { + find_var_tuple_field_path_in_block(package, spec_decl.block, local_var) + } + qsc_fir::fir::CallableImpl::Spec(spec_impl) => find_var_tuple_field_path_in_block( + package, + spec_impl.body.block, + local_var, + ) + .or_else(|| { + [ + spec_impl.adj.as_ref(), + spec_impl.ctl.as_ref(), + spec_impl.ctl_adj.as_ref(), + ] + .into_iter() + .flatten() + .find_map(|spec| find_var_tuple_field_path_in_block(package, spec.block, local_var)) + }), + } +} + +/// Walks a block's statements looking for a `PatKind::Tuple` binding that +/// contains the requested local variable. +fn find_var_tuple_field_path_in_block( + package: &Package, + block_id: qsc_fir::fir::BlockId, + local_var: LocalVarId, +) -> Option> { + let block = package.get_block(block_id); + for &stmt_id in &block.stmts { + let stmt = package.get_stmt(stmt_id); + if let StmtKind::Local(_, pat_id, _) = stmt.kind + && let Some(path) = find_var_field_path_in_pat(package, pat_id, local_var) + && !path.is_empty() + { + return Some(path); + } + // Also descend into nested blocks and control flow + let nested = match stmt.kind { + StmtKind::Expr(expr_id) | StmtKind::Semi(expr_id) | StmtKind::Local(_, _, expr_id) => { + find_var_tuple_field_path_in_expr(package, expr_id, local_var) + } + StmtKind::Item(_) => None, + }; + if nested.is_some() { + return nested; + } + } + None +} + +/// Descends into nested expressions (blocks, conditionals, loops) to find +/// the tuple field path of a local variable binding. 
+fn find_var_tuple_field_path_in_expr( + package: &Package, + expr_id: ExprId, + local_var: LocalVarId, +) -> Option> { + let expr = package.get_expr(expr_id); + match &expr.kind { + ExprKind::Block(block_id) | ExprKind::While(_, block_id) => { + find_var_tuple_field_path_in_block(package, *block_id, local_var) + } + ExprKind::If(_, body, otherwise) => { + find_var_tuple_field_path_in_expr(package, *body, local_var).or_else(|| { + otherwise.and_then(|e| find_var_tuple_field_path_in_expr(package, e, local_var)) + }) + } + _ => None, + } +} + +/// Recursively finds the tuple field path for a local variable within a +/// pattern tree. Returns `Some(vec![])` for a direct bind, +/// `Some(vec![1])` for position 1 in a tuple pattern, etc. +fn find_var_field_path_in_pat( + package: &Package, + pat_id: PatId, + local_var: LocalVarId, +) -> Option> { + let pat = package.get_pat(pat_id); + match &pat.kind { + PatKind::Bind(ident) if ident.id == local_var => Some(Vec::new()), + PatKind::Bind(_) | PatKind::Discard => None, + PatKind::Tuple(sub_pats) => { + for (i, &sub_pat_id) in sub_pats.iter().enumerate() { + if let Some(mut path) = find_var_field_path_in_pat(package, sub_pat_id, local_var) { + path.insert(0, i); + return Some(path); + } + } + None + } + } +} + +/// Rewrites the callee expression of a direct call to reference the +/// specialized target callable and updates its type accordingly. +/// +/// # Before +/// ```text +/// Var(original_item) : OldArrow // callee expr +/// ``` +/// # After +/// ```text +/// Var(specialized_item) : NewArrow // callee replaced and retyped +/// ``` +/// +/// # Mutations +/// - Overwrites the callee `Expr` node in place via +/// [`rewrite_item_callee_with_functor`]. +/// - May allocate functor-wrapper `Expr` nodes through `assigner`. 
+fn rewrite_direct_callee( + package: &mut Package, + package_id: PackageId, + callee_id: ExprId, + callable: &ConcreteCallable, + _captures: &[CapturedVar], + controlled_layers: usize, + assigner: &mut Assigner, +) { + let callee_expr = package.get_expr(callee_id).clone(); + let (item_id, functor, callee_ty) = match callable { + ConcreteCallable::Global { item_id, functor } => { + let callee_ty = if item_id.package == package_id + && direct_lambda_packaged_input(package, item_id.item).is_some() + { + build_direct_global_callee_ty(package, *item_id, &callee_expr.ty, controlled_layers) + .unwrap_or_else(|| callee_expr.ty.clone()) + } else { + callee_expr.ty.clone() + }; + (*item_id, *functor, callee_ty) + } + ConcreteCallable::Closure { + target, functor, .. + } => { + let item_id = ItemId { + package: package_id, + item: *target, + }; + ( + item_id, + *functor, + build_direct_global_callee_ty(package, item_id, &callee_expr.ty, controlled_layers) + .unwrap_or_else(|| callee_expr.ty.clone()), + ) + } + ConcreteCallable::Dynamic => return, + }; + + rewrite_item_callee_with_functor(package, callee_id, item_id, callee_ty, functor, assigner); +} + +/// Rewrites the argument tuple of a direct call whose callable argument +/// was a closure, splicing captured values into the argument layout. +/// +/// # Before +/// ```text +/// original_args : OriginalInputTy +/// ``` +/// # After +/// ```text +/// (capture_0, ..., capture_n, original_args) : (CaptureTys..., OriginalInputTy) +/// ``` +/// +/// # Mutations +/// - Rewrites `args_id`'s `ExprKind` and `Ty` in place to a `Tuple` +/// containing capture expressions followed by the original args. +/// - Allocates capture `Expr` nodes through `assigner`. +/// - For controlled operations, recurses through control-qubit layers. 
+fn rewrite_direct_closure_args( + package: &mut Package, + args_id: ExprId, + captures: &[CapturedVar], + controlled_layers: usize, + assigner: &mut Assigner, +) { + if controlled_layers > 0 { + let inner_id = match package.get_expr(args_id).kind { + ExprKind::Tuple(ref elements) if elements.len() > 1 => elements[1], + _ => return, + }; + rewrite_direct_closure_args(package, inner_id, captures, controlled_layers - 1, assigner); + let inner_ty = package.get_expr(inner_id).ty.clone(); + let args_mut = package.exprs.get_mut(args_id).expect("args expr not found"); + if let Ty::Tuple(ref mut tys) = args_mut.ty + && tys.len() > 1 + { + tys[1] = inner_ty; + } + return; + } + + let args_expr = package.get_expr(args_id).clone(); + let capture_ids = allocate_capture_exprs(package, args_expr.span, captures, assigner); + let capture_tys: Vec = captures.iter().map(|capture| capture.ty.clone()).collect(); + + let preserved_args_id = assigner.next_expr(); + package.exprs.insert( + preserved_args_id, + Expr { + id: preserved_args_id, + span: args_expr.span, + ty: args_expr.ty.clone(), + kind: args_expr.kind, + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + + let mut new_elements = capture_ids; + new_elements.push(preserved_args_id); + let mut new_tys = capture_tys; + new_tys.push(args_expr.ty); + + let args_mut = package.exprs.get_mut(args_id).expect("args expr not found"); + args_mut.kind = ExprKind::Tuple(new_elements); + args_mut.ty = Ty::Tuple(new_tys); +} + +/// Builds the arrow type for a direct call to a global specialized target, +/// matching the caller's expected signature after controlled-layer peeling. 
+fn build_direct_global_callee_ty( + package: &Package, + item_id: ItemId, + callee_ty: &Ty, + controlled_layers: usize, +) -> Option { + let Ty::Arrow(arrow) = callee_ty else { + return None; + }; + let ItemKind::Callable(decl) = &package.get_item(item_id.item).kind else { + return None; + }; + let target_input = package.get_pat(decl.input).ty.clone(); + let new_input = + apply_target_input_at_control_path(&arrow.input, &target_input, controlled_layers); + + Some(Ty::Arrow(Box::new(Arrow { + kind: arrow.kind, + input: Box::new(new_input), + output: arrow.output.clone(), + functors: arrow.functors, + }))) +} + +/// Replaces the innermost input slot beneath `controlled_layers` nested +/// controlled-operation tuples with `target_input`, returning the rewritten +/// outer type. +/// +/// A copy of this helper also lives in +/// [`super::specialize::apply_target_input_at_control_path`]; keep the two +/// in sync when changing controlled-layer handling (see the module-level +/// note for why both copies exist). +fn apply_target_input_at_control_path( + current_input: &Ty, + target_input: &Ty, + controlled_layers: usize, +) -> Ty { + if controlled_layers == 0 { + return target_input.clone(); + } + + match current_input { + Ty::Tuple(items) if items.len() > 1 => { + let mut new_items = items.clone(); + new_items[1] = apply_target_input_at_control_path( + &new_items[1], + target_input, + controlled_layers - 1, + ); + Ty::Tuple(new_items) + } + _ => target_input.clone(), + } +} + +/// Returns the packaged input tuple type for a direct call to a lambda +/// target whose parameters live in a one-element tuple. +/// +/// Relies on the naming contract with the producer pass: lifted lambdas +/// that take a single tuple parameter are named with a leading `""` +/// prefix. Do not rename lambda items without updating this predicate. 
+fn direct_lambda_packaged_input(package: &Package, item_id: LocalItemId) -> Option { + let ItemKind::Callable(decl) = &package.get_item(item_id).kind else { + return None; + }; + + let input_ty = package.get_pat(decl.input).ty.clone(); + if decl.name.name.as_ref().starts_with("") + && matches!(&input_ty, Ty::Tuple(items) if items.len() == 1) + { + Some(input_ty) + } else { + None + } +} + +/// Builds a single direct-call branch for index-dispatch synthesis by +/// materializing the callee expression, argument tuple, and capture +/// splicing for one specialized callable. +/// +/// # Before +/// ```text +/// (no expression — branch does not yet exist) +/// ``` +/// # After +/// ```text +/// Call(Var(specialized_item), (captures..., args)) : result_ty +/// ``` +/// +/// # Mutations +/// - Allocates callee, args, and call `Expr` nodes through `assigner`. +#[allow(clippy::too_many_arguments)] +fn create_direct_branch_call( + package: &mut Package, + package_id: PackageId, + orig_callee: &Expr, + orig_args: &Expr, + span: Span, + result_ty: &Ty, + direct_call_site: &DirectCallSite, + assigner: &mut Assigner, +) -> ExprId { + let captures = match &direct_call_site.callable { + ConcreteCallable::Closure { captures, .. } => { + resolve_rewrite_captures(package, orig_callee.id, captures) + } + _ => Vec::new(), + }; + let (_, outer_functor) = peel_body_functors(package, orig_callee.id); + let controlled_layers = usize::from(outer_functor.controlled); + let package_direct_lambda_input = match &direct_call_site.callable { + ConcreteCallable::Global { item_id, .. 
} if item_id.package == package_id => { + direct_lambda_packaged_input(package, item_id.item) + } + _ => None, + }; + let package_direct_lambda = matches!( + package_direct_lambda_input.as_ref(), + Some(target_input) + if apply_target_input_at_control_path(&orig_args.ty, target_input, controlled_layers) + != orig_args.ty + ); + + let (item_id, functor, callee_ty) = match &direct_call_site.callable { + ConcreteCallable::Global { item_id, functor } => { + let callee_ty = if item_id.package == package_id + && package_direct_lambda_input.is_some() + { + build_direct_global_callee_ty(package, *item_id, &orig_callee.ty, controlled_layers) + .unwrap_or_else(|| orig_callee.ty.clone()) + } else { + orig_callee.ty.clone() + }; + (*item_id, *functor, callee_ty) + } + ConcreteCallable::Closure { + target, functor, .. + } => { + let item_id = ItemId { + package: package_id, + item: *target, + }; + ( + item_id, + *functor, + build_direct_global_callee_ty(package, item_id, &orig_callee.ty, controlled_layers) + .unwrap_or_else(|| orig_callee.ty.clone()), + ) + } + ConcreteCallable::Dynamic => return orig_callee.id, + }; + + let callee_id = + alloc_item_callee_expr_with_functor(package, span, item_id, &callee_ty, functor, assigner); + let (args_kind, args_ty) = build_direct_branch_args_data( + package, + orig_args, + &captures, + controlled_layers, + package_direct_lambda, + assigner, + ); + let args_id = assigner.next_expr(); + package.exprs.insert( + args_id, + Expr { + id: args_id, + span, + ty: args_ty, + kind: args_kind, + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + + let call_id = assigner.next_expr(); + package.exprs.insert( + call_id, + Expr { + id: call_id, + span, + ty: result_ty.clone(), + kind: ExprKind::Call(callee_id, args_id), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + + call_id +} + +/// Assembles the argument-tuple expressions for a direct-call branch, +/// including any capture values that must accompany a closure branch. 
+fn build_direct_branch_args_data( + package: &mut Package, + orig_args: &Expr, + captures: &[CapturedVar], + controlled_layers: usize, + package_direct_lambda: bool, + assigner: &mut Assigner, +) -> (ExprKind, Ty) { + if controlled_layers > 0 { + let ExprKind::Tuple(elements) = &orig_args.kind else { + return build_direct_branch_args_data( + package, + orig_args, + captures, + 0, + package_direct_lambda, + assigner, + ); + }; + let Ty::Tuple(tys) = &orig_args.ty else { + return build_direct_branch_args_data( + package, + orig_args, + captures, + 0, + package_direct_lambda, + assigner, + ); + }; + if elements.len() < 2 || tys.len() < 2 { + return build_direct_branch_args_data( + package, + orig_args, + captures, + 0, + package_direct_lambda, + assigner, + ); + } + + let inner_orig = package.get_expr(elements[1]).clone(); + let (inner_kind, inner_ty) = build_direct_branch_args_data( + package, + &inner_orig, + captures, + controlled_layers - 1, + package_direct_lambda, + assigner, + ); + + let inner_id = assigner.next_expr(); + package.exprs.insert( + inner_id, + Expr { + id: inner_id, + span: inner_orig.span, + ty: inner_ty.clone(), + kind: inner_kind, + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + + return ( + ExprKind::Tuple(vec![elements[0], inner_id]), + Ty::Tuple(vec![tys[0].clone(), inner_ty]), + ); + } + + if captures.is_empty() && !package_direct_lambda { + return (orig_args.kind.clone(), orig_args.ty.clone()); + } + + let capture_ids = allocate_capture_exprs(package, orig_args.span, captures, assigner); + let capture_tys: Vec = captures.iter().map(|capture| capture.ty.clone()).collect(); + + let preserved_args_id = assigner.next_expr(); + package.exprs.insert( + preserved_args_id, + Expr { + id: preserved_args_id, + span: orig_args.span, + ty: orig_args.ty.clone(), + kind: orig_args.kind.clone(), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + + let mut tuple_items = capture_ids; + tuple_items.push(preserved_args_id); + let mut tuple_tys = 
capture_tys; + tuple_tys.push(orig_args.ty.clone()); + + (ExprKind::Tuple(tuple_items), Ty::Tuple(tuple_tys)) +} + +/// Rewrites a single call site to use the specialized callable. +/// +/// # Before +/// ```text +/// Call(Var(hof_item), (callable_arg, other_args)) +/// ``` +/// # After +/// ```text +/// Call(Var(specialized_item), (other_args, captures...)) +/// ``` +/// +/// # Mutations +/// - Rewrites the callee via [`rewrite_specialized_callee`]. +/// - Rewrites args via [`rewrite_args`], removing the callable parameter +/// and appending closure captures. +fn rewrite_one( + package: &mut Package, + package_id: PackageId, + call_site: &CallSite, + param: &CallableParam, + spec_local_id: LocalItemId, + expr_owner_lookup: &FxHashMap, + assigner: &mut Assigner, +) { + let call_expr = package.get_expr(call_site.call_expr_id).clone(); + + let ExprKind::Call(callee_id, args_id) = call_expr.kind else { + return; + }; + + // Replace callee with the specialized callable reference + let spec_item_id = ItemId { + package: package_id, + item: spec_local_id, + }; + + // Build the new callee type: remove the callable param from the arrow input. + let input_path = callable_param_input_path(package, callee_id, param); + let new_callee_ty = + build_specialized_callee_ty(package, callee_id, &input_path, &call_site.callable_arg); + rewrite_specialized_callee(package, callee_id, spec_item_id, new_callee_ty, assigner); + + // Remove the callable argument from the args tuple + // Insert closure captures as extra arguments + let captures = match &call_site.callable_arg { + ConcreteCallable::Closure { captures, .. } => { + resolve_rewrite_captures(package, call_site.arg_expr_id, captures) + } + _ => Vec::new(), + }; + rewrite_args( + package, + call_site.call_expr_id, + args_id, + &input_path, + &captures, + expr_owner_lookup, + assigner, + ); +} + +/// Removes the callable argument selected by `param` from the call arguments +/// and appends closure captures when needed. 
+/// +/// # Before +/// ```text +/// (callable_arg, arg1, arg2) +/// ``` +/// # After +/// ```text +/// (arg1, arg2, capture0, ..., captureN) // callable_arg removed, captures appended +/// ``` +/// +/// # Mutations +/// - Rewrites `args_id`'s `ExprKind` and `Ty` in place. +/// - Allocates capture `Expr` nodes through `assigner`. +fn rewrite_args( + package: &mut Package, + call_expr_id: ExprId, + args_id: ExprId, + input_path: &[usize], + captures: &[CapturedVar], + expr_owner_lookup: &FxHashMap, + assigner: &mut Assigner, +) { + let args_expr = package + .exprs + .get(args_id) + .expect("args expr not found") + .clone(); + + if input_path.is_empty() { + rewrite_single_arg_root(package, args_id, captures, assigner); + } else if matches!(args_expr.kind, ExprKind::Tuple(_)) { + let owner_callable = expr_owner_lookup.get(&call_expr_id).copied(); + if input_path.len() == 1 { + rewrite_args_remove_tuple_element(package, args_id, input_path[0], captures, assigner); + } else { + rewrite_args_nested_tuple_input( + package, + owner_callable, + args_id, + input_path[0], + &input_path[1..], + captures, + assigner, + ); + } + } else { + rewrite_single_arg_nested( + package, + call_expr_id, + args_id, + input_path, + captures, + expr_owner_lookup, + assigner, + ); + } +} + +/// Removes a top-level element from a tuple-structured args expression and +/// appends any closure captures. +/// +/// # Before +/// ```text +/// (arg0, callable_arg, arg2) // param_index = 1 +/// ``` +/// # After +/// ```text +/// (arg0, arg2, capture0, ...) // element removed, captures appended +/// ``` +/// +/// # Mutations +/// - Rewrites `args_id`'s `ExprKind` and `Ty` in place. +/// - Flattens single-element tuples to scalars. +/// - Allocates capture `Expr` nodes through `assigner`. 
+fn rewrite_args_remove_tuple_element( + package: &mut Package, + args_id: ExprId, + param_index: usize, + captures: &[CapturedVar], + assigner: &mut Assigner, +) { + let args_expr = package + .exprs + .get(args_id) + .expect("args expr not found") + .clone(); + + match &args_expr.kind { + ExprKind::Tuple(elements) => { + let mut new_elements: Vec = elements + .iter() + .enumerate() + .filter(|(i, _)| *i != param_index) + .map(|(_, &id)| id) + .collect(); + + // Append capture expressions. + let capture_ids = allocate_capture_exprs(package, args_expr.span, captures, assigner); + new_elements.extend(capture_ids); + + // Rebuild the type. + let new_ty = + build_tuple_ty_without_path(package, &args_expr.ty, &[param_index], captures); + + if new_elements.len() == 1 && captures.is_empty() { + // Flatten single-element tuple to match remove_callable_param + // which flattens the declaration's input pattern. + let single_id = new_elements[0]; + let single_expr = package + .exprs + .get(single_id) + .expect("expr not found") + .clone(); + let args_mut = package.exprs.get_mut(args_id).expect("args expr not found"); + args_mut.kind = single_expr.kind; + args_mut.ty = single_expr.ty; + } else { + let args_mut = package.exprs.get_mut(args_id).expect("args expr not found"); + args_mut.kind = ExprKind::Tuple(new_elements); + args_mut.ty = new_ty; + } + } + _ => { + rewrite_single_arg_root(package, args_id, captures, assigner); + } + } +} + +/// Rewrites args for a nested callable inside a top-level tuple input slot. +/// Captures are appended to the top-level args tuple. +/// +/// # Before +/// ```text +/// (ctrl_qubits, (callable_arg, inner_arg)) // field_path = [0] +/// ``` +/// # After +/// ```text +/// (ctrl_qubits, (inner_arg), capture0, ...) // nested element removed +/// ``` +/// +/// # Mutations +/// - Rewrites the inner element via [`rewrite_local_single_arg_nested`] or +/// [`remove_element_at_path`], then updates the outer tuple's type. 
+/// - Allocates capture `Expr` nodes through `assigner`. +fn rewrite_args_nested_tuple_input( + package: &mut Package, + owner_callable: Option, + args_id: ExprId, + top_level_param: usize, + field_path: &[usize], + captures: &[CapturedVar], + assigner: &mut Assigner, +) { + let args_expr = package + .exprs + .get(args_id) + .expect("args expr not found") + .clone(); + + if let ExprKind::Tuple(elements) = &args_expr.kind { + let inner_id = elements[top_level_param]; + if !rewrite_local_single_arg_nested( + package, + owner_callable, + inner_id, + field_path, + &[], + assigner, + ) { + // Remove the nested element from the inner tuple. + remove_element_at_path(package, inner_id, field_path); + } + + // Read the updated inner type before mutably borrowing the outer. + let inner_ty = package + .exprs + .get(inner_id) + .expect("expr not found") + .ty + .clone(); + + // Append captures to the top-level tuple if any. + if captures.is_empty() { + // Update the outer tuple's type for the modified inner element. + let args_mut = package.exprs.get_mut(args_id).expect("args expr not found"); + if let Ty::Tuple(ref mut tys) = args_mut.ty { + tys[top_level_param] = inner_ty; + } + } else { + let capture_ids = allocate_capture_exprs(package, args_expr.span, captures, assigner); + let capture_tys: Vec = captures.iter().map(|c| c.ty.clone()).collect(); + let args_mut = package.exprs.get_mut(args_id).expect("args expr not found"); + if let ExprKind::Tuple(ref mut elems) = args_mut.kind { + elems.extend(capture_ids); + } + if let Ty::Tuple(ref mut tys) = args_mut.ty { + tys[top_level_param] = inner_ty; + tys.extend(capture_tys); + } + } + } +} + +/// Rewrites args when the callable is nested inside the single argument value. +/// +/// # Before +/// ```text +/// args = local_udt // UDT/tuple containing callable at field_path +/// ``` +/// # After +/// ```text +/// args = (remaining_fields, captures...) 
// callable field removed +/// ``` +/// +/// # Mutations +/// - Delegates to [`rewrite_local_single_arg_nested`] when the arg is a +/// local whose initializer can be decomposed, otherwise falls back to +/// [`remove_element_at_path`]. +/// - Allocates capture `Expr` nodes through `assigner`. +fn rewrite_single_arg_nested( + package: &mut Package, + call_expr_id: ExprId, + args_id: ExprId, + field_path: &[usize], + captures: &[CapturedVar], + expr_owner_lookup: &FxHashMap, + assigner: &mut Assigner, +) { + if rewrite_local_single_arg_nested( + package, + expr_owner_lookup.get(&call_expr_id).copied(), + args_id, + field_path, + captures, + assigner, + ) { + return; + } + + remove_element_at_path(package, args_id, field_path); + if !captures.is_empty() { + let span = package.get_expr(args_id).span; + let capture_ids = allocate_capture_exprs(package, span, captures, assigner); + let modified_expr = package.exprs.get(args_id).expect("expr not found").clone(); + let mut new_elements = match &modified_expr.kind { + ExprKind::Tuple(elems) => elems.clone(), + _ => vec![args_id], + }; + new_elements.extend(capture_ids); + let capture_tys: Vec = captures.iter().map(|c| c.ty.clone()).collect(); + let mut new_tys = match &modified_expr.ty { + Ty::Tuple(tys) => tys.clone(), + ty => vec![ty.clone()], + }; + new_tys.extend(capture_tys); + let args_mut = package.exprs.get_mut(args_id).expect("args expr not found"); + args_mut.kind = ExprKind::Tuple(new_elements); + args_mut.ty = Ty::Tuple(new_tys); + } +} + +/// Rewrites a single local UDT/tuple argument by replacing the argument use with +/// the local initializer after removing the specialized callable field. +/// +/// # Before +/// ```text +/// args = Var(local_udt) // bound to (field0, callable, field2) +/// ``` +/// # After +/// ```text +/// args = (field0, field2, captures...) // callable field removed +/// ``` +/// +/// # Mutations +/// - Overwrites `args_id`'s `ExprKind` and `Ty` in place. 
+/// - Allocates capture `Expr` nodes through `assigner`. +fn rewrite_local_single_arg_nested( + package: &mut Package, + owner_callable: Option, + args_id: ExprId, + field_path: &[usize], + captures: &[CapturedVar], + assigner: &mut Assigner, +) -> bool { + if field_path.len() != 1 { + return false; + } + + let ExprKind::Var(Res::Local(local_var), _) = package.get_expr(args_id).kind else { + return false; + }; + let Some(owner_callable) = owner_callable else { + return false; + }; + let Some(init_expr_id) = find_local_init_expr_in_callable(package, owner_callable, local_var) + else { + return false; + }; + let Some((kind, ty)) = remove_top_level_field_from_expr_data( + package, + init_expr_id, + field_path[0], + captures, + assigner, + ) else { + return false; + }; + + let args_expr = package.exprs.get_mut(args_id).expect("args expr not found"); + args_expr.kind = kind; + args_expr.ty = ty; + true +} + +/// Builds replacement expression data for a call-argument aggregate after the +/// top-level callable field has been removed. +/// +/// Before, the tuple or struct represented by `expr_id` still contains the +/// callable-valued field selected by `field_index`. After, the returned +/// `ExprKind`/`Ty` pair describes the same aggregate with that field removed, +/// collapsed when only one element remains, and widened with any closure +/// captures that must become explicit call arguments. 
+fn remove_top_level_field_from_expr_data( + package: &mut Package, + expr_id: ExprId, + field_index: usize, + captures: &[CapturedVar], + assigner: &mut Assigner, +) -> Option<(ExprKind, Ty)> { + let expr = package.get_expr(expr_id).clone(); + let mut remaining = match &expr.kind { + ExprKind::Call(_, args_id) => { + return remove_top_level_field_from_expr_data( + package, + *args_id, + field_index, + captures, + assigner, + ); + } + ExprKind::Tuple(elements) => elements + .iter() + .enumerate() + .filter(|(idx, _)| *idx != field_index) + .map(|(_, &expr_id)| expr_id) + .collect::>(), + ExprKind::Struct(_, _, fields) => fields + .iter() + .filter_map(|field| match &field.field { + Field::Path(path) if path.indices.first() != Some(&field_index) => { + Some(field.value) + } + _ => None, + }) + .collect::>(), + _ => return None, + }; + + remaining.extend(allocate_capture_exprs( + package, expr.span, captures, assigner, + )); + + Some(build_expr_data_from_elements(package, remaining)) +} + +fn build_expr_data_from_elements(package: &Package, elements: Vec) -> (ExprKind, Ty) { + match elements.as_slice() { + [] => (ExprKind::Tuple(Vec::new()), Ty::UNIT), + [single] => { + let expr = package.get_expr(*single); + (expr.kind.clone(), expr.ty.clone()) + } + _ => { + let tys = elements + .iter() + .map(|&expr_id| package.get_expr(expr_id).ty.clone()) + .collect(); + (ExprKind::Tuple(elements), Ty::Tuple(tys)) + } + } +} + +/// Rewrites a single-parameter call's args expression after the callable +/// argument has been removed. +/// +/// Before, `args_id` evaluates to the callable argument itself. After, it +/// evaluates to `()` for a plain global callee or to `(captures...)` when the +/// rewritten direct call must thread closure captures explicitly. 
+fn rewrite_single_arg_root( + package: &mut Package, + args_id: ExprId, + captures: &[CapturedVar], + assigner: &mut Assigner, +) { + let args_expr = package + .exprs + .get(args_id) + .expect("args expr not found") + .clone(); + + if captures.is_empty() { + let args_mut = package.exprs.get_mut(args_id).expect("args expr not found"); + args_mut.kind = ExprKind::Tuple(Vec::new()); + args_mut.ty = Ty::UNIT; + } else { + let capture_ids = allocate_capture_exprs(package, args_expr.span, captures, assigner); + let capture_tys: Vec = captures.iter().map(|c| c.ty.clone()).collect(); + let args_mut = package.exprs.get_mut(args_id).expect("args expr not found"); + args_mut.kind = ExprKind::Tuple(capture_ids); + args_mut.ty = Ty::Tuple(capture_tys); + } +} + +/// Removes the callable argument at `path` from a tuple-valued args expression +/// in place. +/// +/// Before, the tuple nesting rooted at `expr_id` still matches the original +/// higher-order callable input. After, the selected element is removed, empty +/// tuples become unit, and one-element tuples collapse so the remaining shape +/// matches the specialized callee's input. +fn remove_element_at_path(package: &mut Package, expr_id: ExprId, path: &[usize]) { + if path.is_empty() { + return; + } + let expr = package.exprs.get(expr_id).expect("expr not found").clone(); + + if path.len() == 1 { + if let ExprKind::Tuple(elements) = &expr.kind { + let new_elements: Vec = elements + .iter() + .enumerate() + .filter(|(i, _)| *i != path[0]) + .map(|(_, &id)| id) + .collect(); + let new_tys: Vec = if let Ty::Tuple(tys) = &expr.ty { + tys.iter() + .enumerate() + .filter(|(i, _)| *i != path[0]) + .map(|(_, t)| t.clone()) + .collect() + } else { + Vec::new() + }; + + if new_elements.len() == 1 { + // Flatten single-element tuple. 
+ let single = package + .exprs + .get(new_elements[0]) + .expect("expr not found") + .clone(); + let expr_mut = package.exprs.get_mut(expr_id).expect("expr not found"); + expr_mut.kind = single.kind; + expr_mut.ty = single.ty; + } else if new_elements.is_empty() { + let expr_mut = package.exprs.get_mut(expr_id).expect("expr not found"); + expr_mut.kind = ExprKind::Tuple(Vec::new()); + expr_mut.ty = Ty::UNIT; + } else { + let expr_mut = package.exprs.get_mut(expr_id).expect("expr not found"); + expr_mut.kind = ExprKind::Tuple(new_elements); + expr_mut.ty = Ty::Tuple(new_tys); + } + } + } else if let ExprKind::Tuple(elements) = &expr.kind { + let inner_id = elements[path[0]]; + remove_element_at_path(package, inner_id, &path[1..]); + // Update the outer tuple's type for the modified inner element. + let inner_expr = package.exprs.get(inner_id).expect("expr not found"); + let inner_ty = inner_expr.ty.clone(); + let expr_mut = package.exprs.get_mut(expr_id).expect("expr not found"); + if let Ty::Tuple(ref mut tys) = expr_mut.ty { + tys[path[0]] = inner_ty; + } + } +} + +/// Materializes the capture operands that must be appended to rewritten call +/// arguments. +/// +/// Before, each capture is represented only by analysis metadata: an optional +/// existing `ExprId` and the local it denotes. After, every capture has a +/// concrete `ExprId` that can be spliced into a tuple, reusing the recorded +/// expression when possible and otherwise synthesizing `Var(Local(_))` nodes. 
+fn allocate_capture_exprs( + package: &mut Package, + span: Span, + captures: &[CapturedVar], + assigner: &mut Assigner, +) -> Vec { + if captures.is_empty() { + return Vec::new(); + } + + let mut ids = Vec::with_capacity(captures.len()); + + for capture in captures { + if let Some(expr_id) = capture.expr { + ids.push(expr_id); + continue; + } + + let new_id = assigner.next_expr(); + let new_expr = Expr { + id: new_id, + span, + ty: capture.ty.clone(), + kind: ExprKind::Var(Res::Local(capture.var), Vec::new()), + exec_graph_range: EMPTY_EXEC_RANGE, + }; + package.exprs.insert(new_id, new_expr); + ids.push(new_id); + } + + ids +} + +/// Computes the callee arrow type that corresponds to a rewritten direct call. +/// +/// Before, the callee type still includes the callable-valued parameter from +/// the original higher-order signature. After, the returned arrow removes that +/// input slot and appends any closure capture types so the callee type matches +/// the rewritten args expression. +fn build_specialized_callee_ty( + package: &Package, + callee_id: ExprId, + input_path: &[usize], + concrete: &ConcreteCallable, +) -> Option { + let callee_expr = package.get_expr(callee_id); + let Ty::Arrow(ref arrow) = callee_expr.ty else { + return None; + }; + + let captures = match concrete { + ConcreteCallable::Closure { captures, .. } => captures.as_slice(), + _ => &[], + }; + + let new_input = remove_ty_at_path(package, &arrow.input, input_path, captures); + Some(Ty::Arrow(Box::new(Arrow { + kind: arrow.kind, + input: Box::new(new_input), + output: arrow.output.clone(), + functors: arrow.functors, + }))) +} + +/// Removes the type at a given path from a tuple type and appends capture types. +/// For single-element paths, removes the element at that index from the tuple. +/// For multi-element paths, navigates into nested tuples to remove the element. +/// An empty path removes the entire root value. 
If the type is not a tuple, +/// it represents the single callable-param case, so the result is either Unit +/// or a tuple of capture types. +fn remove_ty_at_path(package: &Package, ty: &Ty, path: &[usize], captures: &[CapturedVar]) -> Ty { + let capture_tys: Vec = captures.iter().map(|c| c.ty.clone()).collect(); + + if path.is_empty() { + return if capture_tys.is_empty() { + Ty::UNIT + } else { + Ty::Tuple(capture_tys) + }; + } + + let ty = resolve_udt_ty(package, ty); + + if path.len() == 1 { + if let Ty::Tuple(tys) = &ty { + let mut remaining: Vec = tys + .iter() + .enumerate() + .filter(|(i, _)| *i != path[0]) + .map(|(_, t)| t.clone()) + .collect(); + remaining.extend(capture_tys); + if remaining.is_empty() { + Ty::UNIT + } else if remaining.len() == 1 && captures.is_empty() { + // Flatten single-element tuple to match pattern flattening. + remaining + .into_iter() + .next() + .expect("single element should exist") + } else { + Ty::Tuple(remaining) + } + } else { + // Single param is the callable — result is captures or unit. + if capture_tys.is_empty() { + Ty::UNIT + } else { + Ty::Tuple(capture_tys) + } + } + } else { + // Navigate deeper: modify the sub-type at path[0], then rebuild. + if let Ty::Tuple(tys) = &ty { + let mut new_tys = tys.clone(); + // Remove nested element without captures at inner level. + new_tys[path[0]] = remove_ty_at_path(package, &tys[path[0]], &path[1..], &[]); + // Append captures at the top level. + new_tys.extend(capture_tys); + Ty::Tuple(new_tys) + } else { + // Single param that is a tuple type — remove from within. + let modified = remove_ty_at_path(package, &ty, &path[1..], &[]); + if capture_tys.is_empty() { + modified + } else { + let mut all = vec![modified]; + all.extend(capture_tys); + Ty::Tuple(all) + } + } + } +} + +/// Builds the tuple type for the args expression after removing the element at +/// `param_path` and appending capture types. 
+fn build_tuple_ty_without_path( + package: &Package, + ty: &Ty, + param_path: &[usize], + captures: &[CapturedVar], +) -> Ty { + remove_ty_at_path(package, ty, param_path, captures) +} + +fn local_ty_contains_arrow_through_udts(package: &Package, ty: &Ty) -> bool { + ty_contains_arrow(&resolve_udt_ty(package, ty)) +} + +fn resolve_udt_ty(package: &Package, ty: &Ty) -> Ty { + match ty { + Ty::Udt(Res::Item(item_id)) => { + let Some(item) = package.items.get(item_id.item) else { + return ty.clone(); + }; + let ItemKind::Ty(_, udt) = &item.kind else { + return ty.clone(); + }; + resolve_udt_ty(package, &udt.get_pure_ty()) + } + Ty::Tuple(elems) => Ty::Tuple( + elems + .iter() + .map(|elem| resolve_udt_ty(package, elem)) + .collect(), + ), + Ty::Array(elem) => Ty::Array(Box::new(resolve_udt_ty(package, elem))), + Ty::Arrow(arrow) => Ty::Arrow(Box::new(Arrow { + kind: arrow.kind, + input: Box::new(resolve_udt_ty(package, &arrow.input)), + output: Box::new(resolve_udt_ty(package, &arrow.output)), + functors: arrow.functors, + })), + _ => ty.clone(), + } +} + +fn callable_uses_tuple_input_pattern(package: &Package, callable_id: LocalItemId) -> bool { + let item = package.get_item(callable_id); + match &item.kind { + ItemKind::Callable(decl) => matches!(package.get_pat(decl.input).kind, PatKind::Tuple(_)), + _ => false, + } +} + +fn callable_param_input_path( + package: &Package, + callee_id: ExprId, + param: &CallableParam, +) -> Vec { + let (_, outer_functor) = peel_body_functors(package, callee_id); + let uses_tuple = callable_uses_tuple_input_pattern(package, param.callable_id); + super::build_param_input_path(uses_tuple, param, outer_functor) +} + +/// Replaces `callee_id` with a reference to the specialized callable while +/// preserving any outer functor shell. +/// +/// Before, the callee subtree still refers to the original higher-order item. 
+/// After, the same root `ExprId` evaluates the specialized callable and carries +/// the rewritten arrow type expected by the direct-call args. +fn rewrite_specialized_callee( + package: &mut Package, + callee_id: ExprId, + spec_item_id: ItemId, + new_callee_ty: Option, + assigner: &mut Assigner, +) { + let (_, outer_functor) = peel_body_functors(package, callee_id); + let callee_expr = package.get_expr(callee_id).clone(); + let callee_ty = new_callee_ty.unwrap_or_else(|| callee_expr.ty.clone()); + + rewrite_item_callee_with_functor( + package, + callee_id, + spec_item_id, + callee_ty, + outer_functor, + assigner, + ); +} + +/// Overwrites `callee_id` so it names `item_id`, rebuilding any `Adj`/`Ctl` +/// wrapper chain around a fresh inner `Var` expression. +/// +/// # Before +/// ```text +/// Ctl(Adj(Var(original_item))) : OldArrow +/// ``` +/// # After +/// ```text +/// Ctl(Adj(Var(specialized_item))) : NewArrow +/// ``` +/// +/// # Mutations +/// - Rewrites `callee_id`'s `ExprKind` and `Ty` in place. +/// - Allocates fresh inner `Var` and functor-wrapper `Expr` nodes through +/// `assigner` when the functor chain is non-trivial. +fn rewrite_item_callee_with_functor( + package: &mut Package, + callee_id: ExprId, + item_id: ItemId, + callee_ty: Ty, + functor: FunctorApp, + assigner: &mut Assigner, +) { + let callee_expr = package.get_expr(callee_id).clone(); + + if !functor.adjoint && functor.controlled == 0 { + let expr = package + .exprs + .get_mut(callee_id) + .expect("callee expr not found"); + expr.kind = ExprKind::Var(Res::Item(item_id), Vec::new()); + expr.ty = callee_ty; + return; + } + + // Rebuild the functor wrapper chain from the inside out, then copy the + // outermost node back into the original callee slot. 
+ let mut current_id = assigner.next_expr(); + package.exprs.insert( + current_id, + Expr { + id: current_id, + span: callee_expr.span, + ty: callee_ty.clone(), + kind: ExprKind::Var(Res::Item(item_id), Vec::new()), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + + if functor.adjoint { + let adj_id = assigner.next_expr(); + package.exprs.insert( + adj_id, + Expr { + id: adj_id, + span: callee_expr.span, + ty: callee_ty.clone(), + kind: ExprKind::UnOp(UnOp::Functor(Functor::Adj), current_id), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + current_id = adj_id; + } + + for _ in 0..functor.controlled { + let ctl_id = assigner.next_expr(); + package.exprs.insert( + ctl_id, + Expr { + id: ctl_id, + span: callee_expr.span, + ty: callee_ty.clone(), + kind: ExprKind::UnOp(UnOp::Functor(Functor::Ctl), current_id), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + current_id = ctl_id; + } + + let outermost_kind = package + .exprs + .get(current_id) + .expect("specialized callee wrapper should exist") + .kind + .clone(); + let expr = package + .exprs + .get_mut(callee_id) + .expect("callee expr not found"); + expr.kind = outermost_kind; + expr.ty = callee_ty; +} + +/// Rewrites a call site that has multiple callee candidates (from branch-split +/// analysis) into an if/elif/else dispatch chain where each branch calls the +/// appropriate specialization. +/// +/// # Before +/// ```text +/// Call(Var(hof), (callable_arg, other_args)) +/// ``` +/// # After +/// ```text +/// if cond_0 { Call(Var(spec_0), args_0) } +/// elif cond_1 { Call(Var(spec_1), args_1) } +/// else { Call(Var(spec_default), args_default) } +/// ``` +/// +/// # Mutations +/// - Replaces `call_expr_id`'s `ExprKind` with the dispatch chain. +/// - Allocates per-branch `Call`, callee, args, and `If` `Expr` nodes +/// through `assigner`. 
+#[allow(clippy::too_many_lines)] +fn branch_split_rewrite( + package: &mut Package, + package_id: PackageId, + call_expr_id: ExprId, + entries: &[(&CallSite, LocalItemId, &CallableParam)], + expr_owner_lookup: &FxHashMap, + assigner: &mut Assigner, +) { + let orig_call = package.get_expr(call_expr_id).clone(); + let ExprKind::Call(orig_callee_id, orig_args_id) = orig_call.kind else { + return; + }; + let span = orig_call.span; + let result_ty = orig_call.ty.clone(); + + let mut conditioned: Vec<((&CallSite, LocalItemId, &CallableParam), ExprId)> = Vec::new(); + let mut default: Option<(&CallSite, LocalItemId, &CallableParam)> = None; + for &entry in entries { + if let Some(condition) = entry.0.condition { + conditioned.push((entry, condition)); + } else if default.is_none() { + default = Some(entry); + } + } + + if conditioned.is_empty() + && entries.len() > 1 + && let Some((synthetic_conditioned, default_idx)) = synthesize_callsite_index_dispatch( + package, + expr_owner_lookup, + call_expr_id, + entries, + span, + assigner, + ) + { + conditioned = synthetic_conditioned + .into_iter() + .map(|(entry_idx, condition)| (entries[entry_idx], condition)) + .collect(); + default = Some(entries[default_idx]); + } + + // Must have a default for the else branch; steal last conditioned if needed. + let default_entry = if let Some(d) = default { + d + } else { + if conditioned.is_empty() { + return; + } + conditioned.pop().expect("non-empty conditioned").0 + }; + + if conditioned.is_empty() { + // Single effective entry — use normal rewrite. + rewrite_one( + package, + package_id, + default_entry.0, + default_entry.2, + default_entry.1, + expr_owner_lookup, + assigner, + ); + return; + } + + // Clone original callee and args expressions before modifications. + let orig_callee = package.get_expr(orig_callee_id).clone(); + let orig_args = package.get_expr(orig_args_id).clone(); + + // Create the else (default) branch call. 
+ let else_call_id = create_branch_call( + package, + package_id, + &orig_callee, + &orig_args, + span, + &result_ty, + default_entry.0, + default_entry.2, + default_entry.1, + assigner, + ); + + // Build the if/elif chain from the bottom up. + let mut current_else = else_call_id; + for ((cs, spec_id, param), cond_id) in conditioned.into_iter().rev() { + let branch_call_id = create_branch_call( + package, + package_id, + &orig_callee, + &orig_args, + span, + &result_ty, + cs, + param, + spec_id, + assigner, + ); + current_else = alloc_if_expr( + package, + span, + &result_ty, + cond_id, + branch_call_id, + current_else, + assigner, + ); + } + + // Replace the original call expression with the dispatch chain. + let dispatch = package + .exprs + .get(current_else) + .expect("dispatch expr should exist") + .clone(); + let orig = package + .exprs + .get_mut(call_expr_id) + .expect("call expr should exist"); + orig.kind = dispatch.kind; + orig.ty = dispatch.ty; +} + +/// Creates a single branch's specialised call expression, returning its +/// [`ExprId`]. The callee is replaced with the specialization, the callable +/// argument is removed from the args, and closure captures are appended. +/// +/// # Before +/// ```text +/// (no expression — branch does not yet exist) +/// ``` +/// # After +/// ```text +/// Call(Var(spec_item), (remaining_args, captures...)) : result_ty +/// ``` +/// +/// # Mutations +/// - Allocates callee, args, and call `Expr` nodes through `assigner`. +#[allow(clippy::too_many_arguments)] +fn create_branch_call( + package: &mut Package, + package_id: PackageId, + orig_callee: &Expr, + orig_args: &Expr, + span: Span, + result_ty: &Ty, + call_site: &CallSite, + param: &CallableParam, + spec_local_id: LocalItemId, + assigner: &mut Assigner, +) -> ExprId { + let spec_item_id = ItemId { + package: package_id, + item: spec_local_id, + }; + + // Specialised callee type. 
+    let input_path = callable_param_input_path(package, orig_callee.id, param);
+    let new_callee_ty = build_specialized_callee_ty_from_expr(
+        package,
+        orig_callee,
+        &input_path,
+        &call_site.callable_arg,
+    );
+    let callee_id = alloc_specialized_callee_expr(
+        package,
+        orig_callee,
+        spec_item_id,
+        &new_callee_ty.unwrap_or_else(|| orig_callee.ty.clone()),
+        assigner,
+    );
+
+    // Build args: remove callable param + append captures.
+    let captures = match &call_site.callable_arg {
+        ConcreteCallable::Closure { captures, .. } => {
+            resolve_rewrite_captures(package, call_site.arg_expr_id, captures)
+        }
+        _ => Vec::new(),
+    };
+    let (args_kind, args_ty) =
+        build_branch_args_data(package, orig_args, &input_path, &captures, span, assigner);
+
+    let args_id = assigner.next_expr();
+    package.exprs.insert(
+        args_id,
+        Expr {
+            id: args_id,
+            span,
+            ty: args_ty,
+            kind: args_kind,
+            exec_graph_range: EMPTY_EXEC_RANGE,
+        },
+    );
+
+    // Call expression.
+    let call_id = assigner.next_expr();
+    package.exprs.insert(
+        call_id,
+        Expr {
+            id: call_id,
+            span,
+            ty: result_ty.clone(),
+            kind: ExprKind::Call(callee_id, args_id),
+            exec_graph_range: EMPTY_EXEC_RANGE,
+        },
+    );
+
+    call_id
+}
+
+/// Resolves the defining expressions for the captures referenced in a
+/// direct-call rewrite, using the combined call-argument and block-scope
+/// lookups.
+fn resolve_rewrite_captures(
+    package: &Package,
+    arg_expr_id: ExprId,
+    captures: &[CapturedVar],
+) -> Vec<CapturedVar> {
+    captures
+        .iter()
+        .map(|capture| {
+            let mut resolved = capture.clone();
+            if resolved.expr.is_none() {
+                resolved.expr = resolve_capture_expr_from_arg(package, arg_expr_id, capture.var);
+            }
+            resolved
+        })
+        .collect()
+}
+
+/// Resolves a capture expression by inspecting the call's argument tuple,
+/// used when the capture was passed in directly at the call site.
+fn resolve_capture_expr_from_arg(
+    package: &Package,
+    arg_expr_id: ExprId,
+    capture_var: LocalVarId,
+) -> Option<ExprId> {
+    let expr = package.get_expr(arg_expr_id);
+    match &expr.kind {
+        ExprKind::Block(block_id) => {
+            resolve_capture_expr_from_block(package, *block_id, capture_var)
+        }
+        ExprKind::If(_, body, otherwise) => {
+            resolve_capture_expr_from_arg(package, *body, capture_var).or_else(|| {
+                otherwise.and_then(|else_id| {
+                    resolve_capture_expr_from_arg(package, else_id, capture_var)
+                })
+            })
+        }
+        ExprKind::UnOp(_, inner) => resolve_capture_expr_from_arg(package, *inner, capture_var),
+        _ => None,
+    }
+}
+
+/// Resolves a capture expression by looking up the capture's defining
+/// binding in the enclosing block's local-expression map.
+fn resolve_capture_expr_from_block(
+    package: &Package,
+    block_id: qsc_fir::fir::BlockId,
+    capture_var: LocalVarId,
+) -> Option<ExprId> {
+    let block = package.get_block(block_id);
+    let mut bindings = FxHashMap::default();
+
+    for stmt_id in &block.stmts {
+        let stmt = package.get_stmt(*stmt_id);
+        if let StmtKind::Local(_, pat_id, init_expr_id) = &stmt.kind {
+            collect_block_local_exprs(package, *pat_id, *init_expr_id, &mut bindings);
+        }
+    }
+
+    let mut current = capture_var;
+    for _ in 0..32 {
+        let &expr_id = bindings.get(&current)?;
+        let expr = package.get_expr(expr_id);
+        if let ExprKind::Var(Res::Local(next_var), _) = &expr.kind
+            && *next_var != current
+            && bindings.contains_key(next_var)
+        {
+            current = *next_var;
+            continue;
+        }
+        return Some(expr_id);
+    }
+
+    None
+}
+
+/// Builds a `LocalVarId → ExprId` map from a block's statements, capturing
+/// the initializer expressions for every immutable local binding.
+fn collect_block_local_exprs( + package: &Package, + pat_id: qsc_fir::fir::PatId, + init_expr_id: ExprId, + bindings: &mut FxHashMap, +) { + let pat = package.get_pat(pat_id); + match &pat.kind { + PatKind::Bind(ident) => { + bindings.insert(ident.id, init_expr_id); + } + PatKind::Discard => {} + PatKind::Tuple(pats) => { + for &sub_pat_id in pats { + collect_block_local_exprs(package, sub_pat_id, init_expr_id, bindings); + } + } + } +} + +/// Builds the args `ExprKind` and `Ty` for a branch call by removing the +/// callable parameter and appending closure captures. +fn build_branch_args_data( + package: &mut Package, + orig_args: &Expr, + input_path: &[usize], + captures: &[CapturedVar], + span: Span, + assigner: &mut Assigner, +) -> (ExprKind, Ty) { + if input_path.is_empty() { + // Single-param HOF: the argument IS the callable param. + if captures.is_empty() { + (ExprKind::Tuple(Vec::new()), Ty::UNIT) + } else { + let capture_ids = allocate_capture_exprs(package, span, captures, assigner); + let capture_tys: Vec = captures.iter().map(|c| c.ty.clone()).collect(); + (ExprKind::Tuple(capture_ids), Ty::Tuple(capture_tys)) + } + } else if matches!(orig_args.kind, ExprKind::Tuple(_)) { + match &orig_args.kind { + ExprKind::Tuple(elements) => { + if input_path.len() == 1 { + let mut new_elements: Vec = elements + .iter() + .enumerate() + .filter(|(i, _)| *i != input_path[0]) + .map(|(_, &id)| id) + .collect(); + let capture_ids = allocate_capture_exprs(package, span, captures, assigner); + new_elements.extend(capture_ids); + let new_ty = + build_tuple_ty_without_path(package, &orig_args.ty, input_path, captures); + // Flatten single-element tuple to match the flattening in + // rewrite_args_remove_tuple_element so the partial evaluator + // receives a scalar expression rather than a malformed 1-tuple. 
+ if new_elements.len() == 1 && captures.is_empty() { + let single_id = new_elements[0]; + let single_expr = package.exprs.get(single_id).expect("expr not found"); + (single_expr.kind.clone(), single_expr.ty.clone()) + } else { + (ExprKind::Tuple(new_elements), new_ty) + } + } else { + let new_ty = + build_tuple_ty_without_path(package, &orig_args.ty, input_path, captures); + let mut new_kind = orig_args.kind.clone(); + if let ExprKind::Tuple(ref mut elems) = new_kind { + if let Some(outer_elem_id) = elems.get(input_path[0]).copied() { + remove_element_at_path(package, outer_elem_id, &input_path[1..]); + } + let capture_ids = allocate_capture_exprs(package, span, captures, assigner); + elems.extend(capture_ids); + } + (new_kind, new_ty) + } + } + _ => ( + orig_args.kind.clone(), + build_tuple_ty_without_path(package, &orig_args.ty, input_path, captures), + ), + } + } else if input_path.len() == 1 { + let param_index = input_path[0]; + match &orig_args.kind { + ExprKind::Tuple(elements) => { + let mut new_elements: Vec = elements + .iter() + .enumerate() + .filter(|(i, _)| *i != param_index) + .map(|(_, &id)| id) + .collect(); + let capture_ids = allocate_capture_exprs(package, span, captures, assigner); + new_elements.extend(capture_ids); + let new_ty = + build_tuple_ty_without_path(package, &orig_args.ty, input_path, captures); + // Flatten single-element tuple to match the flattening in + // rewrite_args_remove_tuple_element so the partial evaluator + // receives a scalar expression rather than a malformed 1-tuple. 
+ if new_elements.len() == 1 && captures.is_empty() { + let single_id = new_elements[0]; + let single_expr = package.exprs.get(single_id).expect("expr not found"); + (single_expr.kind.clone(), single_expr.ty.clone()) + } else { + (ExprKind::Tuple(new_elements), new_ty) + } + } + _ => ( + orig_args.kind.clone(), + build_tuple_ty_without_path(package, &orig_args.ty, input_path, captures), + ), + } + } else { + // Nested path: rebuild both the args type and expression with the + // nested element removed. + remove_element_at_path(package, orig_args.id, input_path); + let new_ty = build_tuple_ty_without_path(package, &orig_args.ty, input_path, captures); + let modified_args = package.get_expr(orig_args.id).clone(); + let new_kind = if captures.is_empty() { + modified_args.kind + } else { + let capture_ids = allocate_capture_exprs(package, span, captures, assigner); + if let ExprKind::Tuple(mut elems) = modified_args.kind { + elems.extend(capture_ids); + ExprKind::Tuple(elems) + } else { + let mut elems = vec![orig_args.id]; + elems.extend(capture_ids); + ExprKind::Tuple(elems) + } + }; + (new_kind, new_ty) + } +} + +/// Allocates a fresh `Var` expression that references a specialized +/// callable item, returning its new `ExprId`. +/// +/// # Mutations +/// - Delegates to [`alloc_item_callee_expr_with_functor`], which inserts +/// the `Var` and any functor wrapper `Expr` nodes. +fn alloc_specialized_callee_expr( + package: &mut Package, + orig_callee: &Expr, + spec_item_id: ItemId, + callee_ty: &Ty, + assigner: &mut Assigner, +) -> ExprId { + let (_, outer_functor) = peel_body_functors(package, orig_callee.id); + alloc_item_callee_expr_with_functor( + package, + orig_callee.span, + spec_item_id, + callee_ty, + outer_functor, + assigner, + ) +} + +/// Allocates a fresh callee expression that wraps an item reference with +/// the requested functor applications (`Adj` and/or `Ctl` layers). 
+/// +/// # Before +/// ```text +/// (no expression) +/// ``` +/// # After +/// ```text +/// Ctl(...(Adj(Var(item_id)))) : callee_ty // functor chain built +/// ``` +/// +/// # Mutations +/// - Inserts one `Var` `Expr` and zero or more functor-wrapper `Expr` +/// nodes into `package` through `assigner`. +fn alloc_item_callee_expr_with_functor( + package: &mut Package, + span: Span, + item_id: ItemId, + callee_ty: &Ty, + functor: FunctorApp, + assigner: &mut Assigner, +) -> ExprId { + let mut current_id = assigner.next_expr(); + package.exprs.insert( + current_id, + Expr { + id: current_id, + span, + ty: callee_ty.clone(), + kind: ExprKind::Var(Res::Item(item_id), Vec::new()), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + + if functor.adjoint { + let adj_id = assigner.next_expr(); + package.exprs.insert( + adj_id, + Expr { + id: adj_id, + span, + ty: callee_ty.clone(), + kind: ExprKind::UnOp(UnOp::Functor(Functor::Adj), current_id), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + current_id = adj_id; + } + + for _ in 0..functor.controlled { + let ctl_id = assigner.next_expr(); + package.exprs.insert( + ctl_id, + Expr { + id: ctl_id, + span, + ty: callee_ty.clone(), + kind: ExprKind::UnOp(UnOp::Functor(Functor::Ctl), current_id), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + current_id = ctl_id; + } + + current_id +} + +/// Allocates a new `ExprKind::If` expression and inserts it into the package. +/// +/// # Before +/// ```text +/// (no expression) +/// ``` +/// # After +/// ```text +/// If(cond_id, true_id, Some(false_id)) : result_ty +/// ``` +/// +/// # Mutations +/// - Inserts one `Expr` node through `assigner`. 
+fn alloc_if_expr( + package: &mut Package, + span: Span, + result_ty: &Ty, + cond_id: ExprId, + true_id: ExprId, + false_id: ExprId, + assigner: &mut Assigner, +) -> ExprId { + let if_id = assigner.next_expr(); + package.exprs.insert( + if_id, + Expr { + id: if_id, + span, + ty: result_ty.clone(), + kind: ExprKind::If(cond_id, true_id, Some(false_id)), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + if_id +} + +/// Builds the specialised callee type from a saved callee expression snapshot. +fn build_specialized_callee_ty_from_expr( + package: &Package, + callee_expr: &Expr, + input_path: &[usize], + concrete: &ConcreteCallable, +) -> Option { + let Ty::Arrow(ref arrow) = callee_expr.ty else { + return None; + }; + let captures = match concrete { + ConcreteCallable::Closure { captures, .. } => captures.as_slice(), + _ => &[], + }; + let new_input = remove_ty_at_path(package, &arrow.input, input_path, captures); + Some(Ty::Arrow(Box::new(Arrow { + kind: arrow.kind, + input: Box::new(new_input), + output: arrow.output.clone(), + functors: arrow.functors, + }))) +} diff --git a/source/compiler/qsc_fir_transforms/src/defunctionalize/semantic_equivalence_tests.rs b/source/compiler/qsc_fir_transforms/src/defunctionalize/semantic_equivalence_tests.rs new file mode 100644 index 0000000000..35ba2fd79a --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/defunctionalize/semantic_equivalence_tests.rs @@ -0,0 +1,117 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use indoc::formatdoc; +use proptest::prelude::*; + +/// Generates syntactically valid Q# programs exercising defunctionalization's +/// key code paths: lambda arguments, partial application, and direct callable +/// references passed to higher-order functions. +fn defunc_pattern_strategy() -> impl Strategy { + let val = || 0..50i64; + + prop_oneof![ + // 1. Lambda passed as argument to a higher-order function. + (val(), val()).prop_map(|(a, b)| formatdoc! 
{" + namespace Test {{ + function Apply(f : Int -> Int, x : Int) : Int {{ f(x) }} + function Main() : Int {{ + Apply(x -> x + {a}, {b}) + }} + }} + "}), + // 2. Partial application of a two-argument function. + (val(), val()).prop_map(|(a, b)| formatdoc! {" + namespace Test {{ + function Add(x : Int, y : Int) : Int {{ x + y }} + function Apply(f : Int -> Int, x : Int) : Int {{ f(x) }} + function Main() : Int {{ + Apply(Add({a}, _), {b}) + }} + }} + "}), + // 3. Direct callable reference as argument. + val().prop_map(|a| formatdoc! {" + namespace Test {{ + function Double(x : Int) : Int {{ x * 2 }} + function Apply(f : Int -> Int, x : Int) : Int {{ f(x) }} + function Main() : Int {{ + Apply(Double, {a}) + }} + }} + "}), + // 4. Nested higher-order calls: function returning a lambda. + (val(), val()).prop_map(|(a, b)| formatdoc! {" + namespace Test {{ + function MakeAdder(n : Int) : Int -> Int {{ x -> x + n }} + function Apply(f : Int -> Int, x : Int) : Int {{ f(x) }} + function Main() : Int {{ + Apply(MakeAdder({a}), {b}) + }} + }} + "}), + ] +} + +/// Generates programs with multi-capture closures where the captures have +/// distinct values and are used in non-commutative operations, ensuring +/// capture ordering is exercised. +fn multi_capture_strategy() -> impl Strategy { + // Use distinct non-zero values so swapped captures produce a different result. + (2..20i64, 1..10i64) + .prop_filter("a must differ from b", |(a, b)| a != b && *b != 0) + .prop_flat_map(|(a, b)| { + prop_oneof![ + // Two captures used in non-commutative subtraction. + Just(formatdoc! {" + namespace Test {{ + function Apply(f : Int -> Int, x : Int) : Int {{ f(x) }} + function Main() : Int {{ + let a = {a}; + let b = {b}; + Apply(x -> a - b + x, 0) + }} + }} + "}), + // Two captures used in non-commutative division. + Just(formatdoc! 
{" + namespace Test {{ + function Apply(f : Int -> Int, x : Int) : Int {{ f(x) }} + function Main() : Int {{ + let a = {a}; + let b = {b}; + Apply(x -> a / b + x, 0) + }} + }} + "}), + // Three captures in position-sensitive expression. + Just(formatdoc! {" + namespace Test {{ + function Apply(f : Int -> Int, x : Int) : Int {{ f(x) }} + function Main() : Int {{ + let a = {a}; + let b = {b}; + let c = 1; + Apply(x -> (a - b) * c + x, 0) + }} + }} + "}), + ] + }) +} + +proptest! { + #![proptest_config(ProptestConfig::with_cases(50))] + #[test] + fn differential_defunctionalize(source in defunc_pattern_strategy()) { + crate::test_utils::check_semantic_equivalence(&source); + } +} + +proptest! { + #![proptest_config(ProptestConfig::with_cases(30))] + #[test] + fn differential_multi_capture_ordering(source in multi_capture_strategy()) { + crate::test_utils::check_semantic_equivalence(&source); + } +} diff --git a/source/compiler/qsc_fir_transforms/src/defunctionalize/specialize.rs b/source/compiler/qsc_fir_transforms/src/defunctionalize/specialize.rs new file mode 100644 index 0000000000..fc30200ec1 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/defunctionalize/specialize.rs @@ -0,0 +1,2502 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Specialization phase of the defunctionalization pass. +//! +//! For each call site where a higher-order function is invoked with a concrete +//! callable argument, this module clones the HOF body and transforms it so +//! that the callable parameter is replaced by a direct call to the concrete +//! callee. A deduplication map ensures that identical `SpecKey`s produce only +//! one specialization. +//! +//! # Post-transform retyping +//! +//! Cloning a HOF body replaces one or more indirect callable references +//! (typed as arrow) with direct item references (typed as the callable's +//! concrete signature). The surrounding expressions, statements, and blocks +//! 
that flowed those callable values still carry their pre-rewrite arrow
+//! types, so a cascade of `refresh_*_types` helpers
+//! ([`refresh_rewritten_value_types`], [`refresh_block_types`],
+//! [`refresh_stmt_types`], [`refresh_expr_types`]) re-runs type propagation
+//! across the cloned body to re-establish the
+//! [`crate::invariants::InvariantLevel::PostDefunc`] invariant that no
+//! arrow types appear on reachable callable parameters or expressions.
+
+use super::build_spec_key;
+use super::types::{
+    AnalysisResult, CallSite, CallableParam, CapturedVar, ConcreteCallable, Error, SpecKey,
+    compose_functors, peel_body_functors,
+};
+use crate::EMPTY_EXEC_RANGE;
+use crate::cloner::FirCloner;
+use crate::fir_builder::functored_specs;
+use qsc_data_structures::functors::FunctorApp;
+use qsc_data_structures::span::Span;
+use qsc_fir::assigner::Assigner;
+use qsc_fir::fir::{
+    CallableDecl, CallableImpl, Expr, ExprId, ExprKind, Field, FieldPath, Functor, Ident, Item,
+    ItemId, ItemKind, LocalItemId, LocalVarId, NodeId, Package, PackageId, PackageLookup,
+    PackageStore, Pat, PatId, PatKind, Res, UnOp, Visibility,
+};
+use qsc_fir::ty::{Arrow, Ty};
+use rustc_hash::{FxHashMap, FxHashSet};
+use std::rc::Rc;
+
+/// Maximum number of specializations a single HOF may generate before a
+/// warning diagnostic is emitted. Mirrors the LLVM `FuncSpec` `MaxClones`
+/// budget, adapted as a diagnostic-only threshold.
+const EXCESSIVE_SPECIALIZATION_THRESHOLD: usize = 10;
+
+/// Set of `LocalVarId`s that alias a nested callable parameter after
+/// destructuring (e.g. `let (op, _) = pair;` makes `op` an alias).
+type AliasSet = FxHashSet<LocalVarId>;
+
+/// Resolves a `ConcreteCallable` to a compact label for inclusion in
+/// specialized callable names. For globals, produces the callable name
+/// with a functor prefix when non-body (e.g. `H`, `Adj S`, `Ctl X`).
+/// For closures, produces `closure`.
+fn resolve_callable_arg_label(store: &PackageStore, arg: &ConcreteCallable) -> String {
+    match arg {
+        ConcreteCallable::Global { item_id, functor } => {
+            let pkg = store.get(item_id.package);
+            let item = pkg.get_item(item_id.item);
+            let base = if let ItemKind::Callable(decl) = &item.kind {
+                decl.name.name.to_string()
+            } else {
+                format!("Item({})", item_id.item)
+            };
+            match (functor.adjoint, functor.controlled > 0) {
+                (false, false) => base,
+                (true, false) => format!("Adj {base}"),
+                (false, true) => format!("Ctl {base}"),
+                (true, true) => format!("CtlAdj {base}"),
+            }
+        }
+        ConcreteCallable::Closure { .. } => "closure".to_string(),
+        ConcreteCallable::Dynamic => "dynamic".to_string(),
+    }
+}
+
+/// Specializes higher-order functions for each concrete callable argument
+/// discovered during analysis.
+///
+/// Returns a map from `SpecKey` to the `LocalItemId` of the newly created
+/// specialized callable in the target package.
+pub(super) fn specialize(
+    store: &mut PackageStore,
+    package_id: PackageId,
+    analysis: &AnalysisResult,
+    assigner: &mut Assigner,
+) -> (FxHashMap<SpecKey, LocalItemId>, Vec<Error>) {
+    let mut dedup: FxHashMap<SpecKey, LocalItemId> = FxHashMap::default();
+    let mut errors: Vec<Error> = Vec::new();
+    let mut recursion_guard: FxHashSet<SpecKey> = FxHashSet::default();
+
+    // Build a lookup from LocalItemId → CallableParam for quick access.
+    // Use entry().or_insert() to keep the first (lowest-index) param per
+    // callable, ensuring deterministic behavior when a callable has multiple
+    // arrow params.
+    let mut param_lookup: FxHashMap<LocalItemId, &CallableParam> = FxHashMap::default();
+    for p in &analysis.callable_params {
+        param_lookup.entry(p.callable_id).or_insert(p);
+    }
+
+    for call_site in &analysis.call_sites {
+        let spec_key = build_spec_key(call_site);
+
+        // Already specialized — skip.
+ if dedup.contains_key(&spec_key) { + continue; + } + + // Dynamic callables cannot be specialized — emit an error with the + // call-site span so the user gets an actionable diagnostic instead of + // the generic `FixpointNotReached` convergence error. + if matches!(call_site.callable_arg, ConcreteCallable::Dynamic) { + let package = store.get(package_id); + let span = package.get_expr(call_site.call_expr_id).span; + errors.push(Error::DynamicCallable(span)); + continue; + } + + // Skip cross-package HOFs that were NOT cloned into the user + // package by monomorphization. Cross-package HOFs that WERE cloned + // (e.g. generic std lib callables monomorphized with concrete types) + // now exist in the user package's items map and should be processed. + if call_site.hof_item_id.package != package_id { + let pkg = store.get(package_id); + if !pkg.items.contains_key(call_site.hof_item_id.item) { + continue; + } + } + + // Recursive specialization guard. + if recursion_guard.contains(&spec_key) { + let package = store.get(package_id); + let span = package.get_expr(call_site.call_expr_id).span; + errors.push(Error::RecursiveSpecialization(span)); + continue; + } + recursion_guard.insert(spec_key.clone()); + + let hof_local_item = call_site.hof_item_id.item; + + // Look up the callable parameter for this HOF. + let Some(param) = param_lookup.get(&hof_local_item).copied() else { + recursion_guard.remove(&spec_key); + continue; + }; + + // Clone the HOF and produce a specialized callable. + let new_item_id = specialize_one(store, package_id, call_site, param, assigner); + + if let Some(id) = new_item_id { + dedup.insert(spec_key.clone(), id); + } + + recursion_guard.remove(&spec_key); + } + + // Count specializations per HOF and emit a warning when the threshold + // is exceeded. Group dedup entries by the HOF callable_id embedded in + // each SpecKey. 
+    let mut specs_per_hof: FxHashMap<LocalItemId, usize> = FxHashMap::default();
+    for key in dedup.keys() {
+        *specs_per_hof.entry(key.hof_id).or_default() += 1;
+    }
+    for (hof_id, count) in &specs_per_hof {
+        if *count > EXCESSIVE_SPECIALIZATION_THRESHOLD {
+            let package = store.get(package_id);
+            let item = package.get_item(*hof_id);
+            let (name, span) = if let ItemKind::Callable(decl) = &item.kind {
+                (decl.name.name.to_string(), decl.name.span)
+            } else {
+                (format!("Item({hof_id})"), Span::default())
+            };
+            errors.push(Error::ExcessiveSpecializations(name, *count, span));
+        }
+    }
+
+    (dedup, errors)
+}
+
+/// Drives the post-transform retyping cascade across every spec impl of a
+/// freshly cloned callable, re-establishing [`crate::invariants::InvariantLevel::PostDefunc`]
+/// type consistency after callable references become direct.
+///
+/// # Before
+/// ```text
+/// body { Expr.ty, Block.ty, Pat.ty may be stale }
+/// ```
+/// # After
+/// ```text
+/// body { Expr.ty, Block.ty, Pat.ty refreshed from children up }
+/// ```
+///
+/// # Mutations
+/// - Rewrites `Expr.ty`, `Block.ty`, and `Pat.ty` in place across the
+///   entire callable implementation.
+fn refresh_rewritten_value_types(package: &mut Package, callable_impl: &CallableImpl) {
+    match callable_impl {
+        CallableImpl::Intrinsic => {}
+        CallableImpl::Spec(spec_impl) => {
+            refresh_block_types(package, spec_impl.body.block);
+            for spec in functored_specs(spec_impl) {
+                refresh_block_types(package, spec.block);
+            }
+        }
+        CallableImpl::SimulatableIntrinsic(spec) => {
+            refresh_block_types(package, spec.block);
+        }
+    }
+}
+
+/// Re-computes the type of every statement in a block, returning the
+/// refreshed trailing type so enclosing expressions can cascade the update.
+///
+/// # Before
+/// ```text
+/// Block { stmts, ty: stale_ty }
+/// ```
+/// # After
+/// ```text
+/// Block { stmts, ty: trailing_expr.ty } // or Unit if no trailing Expr
+/// ```
+///
+/// # Mutations
+/// - Rewrites `Block.ty` in place.
+/// - Delegates to [`refresh_stmt_types`] for each statement. +fn refresh_block_types(package: &mut Package, block_id: qsc_fir::fir::BlockId) -> Ty { + let stmt_ids = package.get_block(block_id).stmts.clone(); + for stmt_id in stmt_ids { + refresh_stmt_types(package, stmt_id); + } + + let new_ty = package + .get_block(block_id) + .stmts + .last() + .and_then(|stmt_id| match package.get_stmt(*stmt_id).kind { + qsc_fir::fir::StmtKind::Expr(expr_id) => Some(package.get_expr(expr_id).ty.clone()), + qsc_fir::fir::StmtKind::Semi(_) + | qsc_fir::fir::StmtKind::Local(_, _, _) + | qsc_fir::fir::StmtKind::Item(_) => None, + }) + .unwrap_or(Ty::UNIT); + + let block = package.blocks.get_mut(block_id).expect("block not found"); + block.ty = new_ty.clone(); + new_ty +} + +/// Refreshes the type of a single statement and, when it introduces a +/// local binding, updates the bound pattern's type to match the rewritten +/// initializer. +/// +/// # Before +/// ```text +/// Local(pat: OldTy, init_expr: NewTy) +/// ``` +/// # After +/// ```text +/// Local(pat: NewTy, init_expr: NewTy) // pat retyped to match init +/// ``` +/// +/// # Mutations +/// - Rewrites `Pat.ty` in place for `Bind` and `Discard` patterns. +/// - Delegates to [`refresh_expr_types`] for the statement's expression. 
+fn refresh_stmt_types(package: &mut Package, stmt_id: qsc_fir::fir::StmtId) { + let stmt = package.get_stmt(stmt_id).clone(); + match stmt.kind { + qsc_fir::fir::StmtKind::Expr(expr_id) | qsc_fir::fir::StmtKind::Semi(expr_id) => { + let _ = refresh_expr_types(package, expr_id); + } + qsc_fir::fir::StmtKind::Local(_, pat_id, expr_id) => { + let expr_ty = refresh_expr_types(package, expr_id); + let pat_kind = package.get_pat(pat_id).kind.clone(); + if matches!(pat_kind, PatKind::Bind(_) | PatKind::Discard) { + let pat = package.pats.get_mut(pat_id).expect("pat not found"); + pat.ty = expr_ty; + } + } + qsc_fir::fir::StmtKind::Item(_) => {} + } +} + +/// Recomputes the type of an expression after rewriting, propagating the +/// refreshed type through nested blocks, conditionals, calls, and tuple +/// constructors. +/// +/// # Before +/// ```text +/// Expr { kind, ty: stale_ty } +/// ``` +/// # After +/// ```text +/// Expr { kind, ty: recomputed_ty } // derived from children +/// ``` +/// +/// # Mutations +/// - Rewrites `Expr.ty` in place. +/// - Recursively refreshes all reachable sub-expressions. 
+fn refresh_expr_types(package: &mut Package, expr_id: ExprId) -> Ty { + let expr = package.get_expr(expr_id).clone(); + let new_ty = match expr.kind { + ExprKind::Block(block_id) => refresh_block_types(package, block_id), + ExprKind::If(cond_id, body_id, otherwise_id) => { + let _ = refresh_expr_types(package, cond_id); + let body_ty = refresh_expr_types(package, body_id); + if let Some(otherwise_id) = otherwise_id { + let _ = refresh_expr_types(package, otherwise_id); + body_ty + } else { + Ty::UNIT + } + } + ExprKind::Tuple(items) => Ty::Tuple( + items + .into_iter() + .map(|item_id| refresh_expr_types(package, item_id)) + .collect(), + ), + ExprKind::Array(items) | ExprKind::ArrayLit(items) => { + let item_tys: Vec = items + .into_iter() + .map(|item_id| refresh_expr_types(package, item_id)) + .collect(); + if let Some(item_ty) = item_tys.first() { + Ty::Array(Box::new(item_ty.clone())) + } else { + expr.ty + } + } + ExprKind::ArrayRepeat(value_id, count_id) => { + let value_ty = refresh_expr_types(package, value_id); + let _ = refresh_expr_types(package, count_id); + Ty::Array(Box::new(value_ty)) + } + ExprKind::Assign(lhs_id, rhs_id) + | ExprKind::AssignOp(_, lhs_id, rhs_id) + | ExprKind::BinOp(_, lhs_id, rhs_id) + | ExprKind::Index(lhs_id, rhs_id) + | ExprKind::UpdateField(lhs_id, _, rhs_id) + | ExprKind::UpdateIndex(lhs_id, rhs_id, _) + | ExprKind::AssignField(lhs_id, _, rhs_id) + | ExprKind::AssignIndex(lhs_id, rhs_id, _) => { + let _ = refresh_expr_types(package, lhs_id); + let _ = refresh_expr_types(package, rhs_id); + expr.ty + } + ExprKind::While(cond_id, block_id) => { + let _ = refresh_expr_types(package, cond_id); + let _ = refresh_block_types(package, block_id); + expr.ty + } + ExprKind::Call(callee_id, args_id) => { + let _ = refresh_expr_types(package, callee_id); + let _ = refresh_expr_types(package, args_id); + expr.ty + } + ExprKind::UnOp(_, inner_id) + | ExprKind::Return(inner_id) + | ExprKind::Fail(inner_id) + | ExprKind::Field(inner_id, _) 
=> {
+            let _ = refresh_expr_types(package, inner_id);
+            expr.ty
+        }
+        ExprKind::Range(start_id, step_id, end_id) => {
+            if let Some(start_id) = start_id {
+                let _ = refresh_expr_types(package, start_id);
+            }
+            if let Some(step_id) = step_id {
+                let _ = refresh_expr_types(package, step_id);
+            }
+            if let Some(end_id) = end_id {
+                let _ = refresh_expr_types(package, end_id);
+            }
+            expr.ty
+        }
+        ExprKind::String(components) => {
+            for component in components {
+                if let qsc_fir::fir::StringComponent::Expr(component_id) = component {
+                    let _ = refresh_expr_types(package, component_id);
+                }
+            }
+            expr.ty
+        }
+        ExprKind::Struct(_, copy_id, fields) => {
+            if let Some(copy_id) = copy_id {
+                let _ = refresh_expr_types(package, copy_id);
+            }
+            for field in fields {
+                let _ = refresh_expr_types(package, field.value);
+            }
+            expr.ty
+        }
+        ExprKind::Closure(_, _) | ExprKind::Hole | ExprKind::Lit(_) | ExprKind::Var(_, _) => {
+            expr.ty
+        }
+    };
+
+    let expr_mut = package.exprs.get_mut(expr_id).expect("expr not found");
+    expr_mut.ty = new_ty.clone();
+    new_ty
+}
+
+/// Clones a HOF callable, transforms its body to replace the callable
+/// parameter with the concrete callee, and inserts the specialized callable
+/// into the package. Returns the `LocalItemId` of the new item.
+#[allow(clippy::too_many_lines)]
+fn specialize_one(
+    store: &mut PackageStore,
+    package_id: PackageId,
+    call_site: &CallSite,
+    param: &CallableParam,
+    assigner: &mut Assigner,
+) -> Option<LocalItemId> {
+    // Extract needed data from the source package.
+    // The HOF may live in a different package (e.g. the standard library),
+    // so use hof_item_id.package rather than the target package_id.
+ let hof_pkg_id = call_site.hof_item_id.package; + let arg_label = resolve_callable_arg_label(store, &call_site.callable_arg); + let (body_pkg, decl_snapshot) = { + let package = store.get(hof_pkg_id); + let hof_item = package.get_item(call_site.hof_item_id.item); + + let ItemKind::Callable(ref hof_decl) = hof_item.kind else { + return None; + }; + + let body_pkg = extract_callable_body(package, hof_decl); + let decl_snapshot = hof_decl.as_ref().clone(); + (body_pkg, decl_snapshot) + }; // immutable borrow released + + // Clone body into target, transform, insert. + let target = store.get_mut(package_id); + let new_item_id = assigner.next_item(); + let owned_assigner = std::mem::take(assigner); + let mut cloner = FirCloner::from_assigner(owned_assigner); + cloner.reset_maps(); + + // Clone input BEFORE impl so that `local_map` contains input parameter + // mappings when the callable body is walked. + let cloned_input = cloner.clone_pat(&body_pkg, decl_snapshot.input, target); + let cloned_impl = cloner.clone_callable_impl(&body_pkg, &decl_snapshot.implementation, target); + + // Input is cloned BEFORE the body (above), so `local_map` always + // contains the mapping for the original parameter variable. 
+    let remapped_param_var = *cloner
+        .local_map()
+        .get(&param.param_var)
+        .expect("param_var should be in local_map after cloning input first");
+
+    let remapped_param = CallableParam::new(
+        param.callable_id,
+        cloner
+            .pat_map()
+            .get(&param.param_pat_id)
+            .copied()
+            .unwrap_or(param.param_pat_id),
+        param.top_level_param,
+        param.field_path.clone(),
+        remapped_param_var,
+        param.param_ty.clone(),
+    );
+
+    let hof_name: Rc<str> = Rc::from(format!("{}{{{arg_label}}}", decl_snapshot.name.name));
+    let mut new_decl = CallableDecl {
+        id: NodeId::from(0_u32),
+        span: decl_snapshot.span,
+        kind: decl_snapshot.kind,
+        name: Ident {
+            id: LocalVarId::from(0_u32),
+            span: decl_snapshot.name.span,
+            name: hof_name,
+        },
+        generics: decl_snapshot.generics.clone(),
+        input: cloned_input,
+        output: decl_snapshot.output.clone(),
+        functors: decl_snapshot.functors,
+        implementation: cloned_impl,
+        attrs: decl_snapshot.attrs.clone(),
+    };
+
+    // Thread closure captures BEFORE recovering the assigner, since
+    // thread_closure_captures uses the cloner for pat/local allocation.
+    let closure_info = if let ConcreteCallable::Closure {
+        ref captures,
+        target: closure_target,
+        ..
+    } = call_site.callable_arg
+    {
+        let capture_bindings = thread_closure_captures(
+            &mut cloner,
+            target,
+            &mut new_decl,
+            &remapped_param,
+            captures,
+        );
+        Some((closure_target, capture_bindings))
+    } else {
+        None
+    };
+
+    // Recover the assigner from the cloner so all subsequent allocations
+    // flow through the shared pipeline assigner.
+    *assigner = cloner.into_assigner();
+
+    // Transform the body to replace callable param with the concrete callee.
+ let impl_clone = new_decl.implementation.clone(); + transform_callable_body( + target, + package_id, + &impl_clone, + &remapped_param, + &call_site.callable_arg, + assigner, + ); + + if let Some((closure_target, capture_bindings)) = closure_info { + rewrite_closure_target_call_args( + target, + &new_decl.implementation, + package_id, + closure_target, + &capture_bindings, + assigner, + ); + } + + // Remove the callable parameter from the input pattern and update types. + remove_callable_param(target, &mut new_decl, &remapped_param); + refresh_rewritten_value_types(target, &new_decl.implementation); + + // Insert the new item. + let new_item = Item { + id: new_item_id, + span: Span::default(), + parent: None, + doc: Rc::from(""), + attrs: Vec::new(), + visibility: Visibility::Internal, + kind: ItemKind::Callable(Box::new(new_decl)), + }; + target.items.insert(new_item_id, new_item); + + Some(new_item_id) +} + +/// Transforms all specialization bodies in a callable implementation, +/// replacing uses of the callable parameter with direct calls to the concrete +/// callee. 
+fn transform_callable_body( + package: &mut Package, + package_id: PackageId, + callable_impl: &CallableImpl, + param: &CallableParam, + concrete: &ConcreteCallable, + assigner: &mut Assigner, +) { + let mut alias_set = AliasSet::default(); + match callable_impl { + CallableImpl::Intrinsic => {} + CallableImpl::Spec(spec_impl) => { + transform_block( + package, + package_id, + spec_impl.body.block, + param, + concrete, + &mut alias_set, + assigner, + ); + if let Some(ref adj) = spec_impl.adj { + transform_block( + package, + package_id, + adj.block, + param, + concrete, + &mut alias_set, + assigner, + ); + } + if let Some(ref ctl) = spec_impl.ctl { + transform_block( + package, + package_id, + ctl.block, + param, + concrete, + &mut alias_set, + assigner, + ); + } + if let Some(ref ctl_adj) = spec_impl.ctl_adj { + transform_block( + package, + package_id, + ctl_adj.block, + param, + concrete, + &mut alias_set, + assigner, + ); + } + } + CallableImpl::SimulatableIntrinsic(spec_decl) => { + transform_block( + package, + package_id, + spec_decl.block, + param, + concrete, + &mut alias_set, + assigner, + ); + } + } +} + +/// Recursively walks a block, transforming call expressions that invoke the +/// callable parameter. +fn transform_block( + package: &mut Package, + package_id: PackageId, + block_id: qsc_fir::fir::BlockId, + param: &CallableParam, + concrete: &ConcreteCallable, + alias_set: &mut AliasSet, + assigner: &mut Assigner, +) { + let block = package + .blocks + .get(block_id) + .expect("block not found") + .clone(); + for &stmt_id in &block.stmts { + transform_stmt( + package, package_id, stmt_id, param, concrete, alias_set, assigner, + ); + } +} + +/// Walks a pattern tree, returning the `LocalVarId` bound at the given +/// tuple-field path when every intermediate node is a tuple pattern and the +/// leaf is a `Bind`. 
+fn find_bind_local_at_field_path(
+    package: &Package,
+    pat_id: PatId,
+    field_path: &[usize],
+) -> Option<LocalVarId> {
+    let pat = package.get_pat(pat_id);
+    match field_path.split_first() {
+        None => match &pat.kind {
+            PatKind::Bind(ident) => Some(ident.id),
+            PatKind::Tuple(_) | PatKind::Discard => None,
+        },
+        Some((index, tail)) => match &pat.kind {
+            PatKind::Tuple(sub_pats) => sub_pats
+                .get(*index)
+                .and_then(|sub_pat_id| find_bind_local_at_field_path(package, *sub_pat_id, tail)),
+            PatKind::Bind(_) | PatKind::Discard => None,
+        },
+    }
+}
+
+/// Rewrites one statement in a specialized callable body and updates the alias
+/// set used to recognize callable-parameter projections.
+///
+/// Before, destructuring locals in `stmt_id` may still hide the callable
+/// parameter behind tuple-field aliases. After, any newly introduced aliases are
+/// recorded in `alias_set` and all child expressions in the statement have been
+/// visited for direct-call rewriting.
+fn transform_stmt(
+    package: &mut Package,
+    package_id: PackageId,
+    stmt_id: qsc_fir::fir::StmtId,
+    param: &CallableParam,
+    concrete: &ConcreteCallable,
+    alias_set: &mut AliasSet,
+    assigner: &mut Assigner,
+) {
+    let stmt = package.stmts.get(stmt_id).expect("stmt not found").clone();
+    match &stmt.kind {
+        qsc_fir::fir::StmtKind::Expr(expr_id) | qsc_fir::fir::StmtKind::Semi(expr_id) => {
+            transform_expr(
+                package, package_id, *expr_id, param, concrete, alias_set, assigner,
+            );
+        }
+        qsc_fir::fir::StmtKind::Local(_, pat_id, expr_id) => {
+            // Record aliases introduced by destructuring the tuple-valued
+            // parameter down to the callable leaf.
+            if !param.field_path.is_empty() {
+                let init_expr = package.exprs.get(*expr_id).expect("expr not found");
+                if let ExprKind::Var(Res::Local(var), _) = &init_expr.kind {
+                    if *var == param.param_var {
+                        if let Some(alias_var) =
+                            find_bind_local_at_field_path(package, *pat_id, &param.field_path)
+                        {
+                            alias_set.insert(alias_var);
+                        }
+                    } else if alias_set.contains(var) {
+                        let pat = package.pats.get(*pat_id).expect("pat not found");
+                        if let PatKind::Bind(ident) = &pat.kind {
+                            alias_set.insert(ident.id);
+                        }
+                    }
+                }
+            }
+            transform_expr(
+                package, package_id, *expr_id, param, concrete, alias_set, assigner,
+            );
+        }
+        qsc_fir::fir::StmtKind::Item(_) => {}
+    }
+}
+
+/// Rewrites an expression subtree in the cloned specialization so callable
+/// parameter uses become concrete callees.
+///
+/// Before, calls may still target `param.param_var`, a tuple-field projection of
+/// it, or an alias introduced by destructuring. After, every matching callee is
+/// rewritten in place to invoke `concrete`, while nested blocks and control-flow
+/// expressions are recursively normalized to the same post-specialization shape.
+#[allow(clippy::too_many_lines)]
+#[allow(clippy::too_many_arguments)]
+fn transform_expr(
+    package: &mut Package,
+    package_id: PackageId,
+    expr_id: ExprId,
+    param: &CallableParam,
+    concrete: &ConcreteCallable,
+    alias_set: &mut AliasSet,
+    assigner: &mut Assigner,
+) {
+    let expr = package.exprs.get(expr_id).expect("expr not found").clone();
+    match &expr.kind {
+        ExprKind::Call(callee_id, args_id) => {
+            let callee_id = *callee_id;
+            let args_id = *args_id;
+
+            // Check if the callee is our callable parameter (possibly wrapped
+            // in functor applications).
+            let (base_id, body_functor) = peel_body_functors(package, callee_id);
+            let base_kind = package.get_expr(base_id).kind.clone();
+
+            let replaced = if let ExprKind::Var(Res::Local(var), _) = &base_kind
+                && *var == param.param_var
+                && param.field_path.is_empty()
+            {
+                // Single-level param: direct use as callee.
+                replace_callee(
+                    package,
+                    package_id,
+                    callee_id,
+                    body_functor,
+                    concrete,
+                    assigner,
+                );
+                true
+            } else if !param.field_path.is_empty()
+                && expr_matches_param_field_path(
+                    package,
+                    base_id,
+                    param.param_var,
+                    &param.field_path,
+                )
+            {
+                replace_callee(
+                    package,
+                    package_id,
+                    callee_id,
+                    body_functor,
+                    concrete,
+                    assigner,
+                );
+                true
+            } else {
+                false
+            };
+
+            // Also check alias set for nested params.
+            let replaced = if replaced {
+                true
+            } else if let ExprKind::Var(Res::Local(var), _) = &base_kind
+                && alias_set.contains(var)
+            {
+                replace_callee(
+                    package,
+                    package_id,
+                    callee_id,
+                    body_functor,
+                    concrete,
+                    assigner,
+                );
+                true
+            } else {
+                false
+            };
+
+            if !replaced {
+                transform_expr(
+                    package, package_id, callee_id, param, concrete, alias_set, assigner,
+                );
+            }
+
+            // Recurse into the arguments.
+ transform_expr( + package, package_id, args_id, param, concrete, alias_set, assigner, + ); + } + ExprKind::Block(block_id) => { + transform_block( + package, package_id, *block_id, param, concrete, alias_set, assigner, + ); + } + ExprKind::If(cond, body, els) => { + transform_expr( + package, package_id, *cond, param, concrete, alias_set, assigner, + ); + transform_expr( + package, package_id, *body, param, concrete, alias_set, assigner, + ); + if let Some(els_id) = els { + transform_expr( + package, package_id, *els_id, param, concrete, alias_set, assigner, + ); + } + } + ExprKind::While(cond, block_id) => { + transform_expr( + package, package_id, *cond, param, concrete, alias_set, assigner, + ); + transform_block( + package, package_id, *block_id, param, concrete, alias_set, assigner, + ); + } + ExprKind::Tuple(exprs) | ExprKind::Array(exprs) | ExprKind::ArrayLit(exprs) => { + for &e in exprs { + transform_expr(package, package_id, e, param, concrete, alias_set, assigner); + } + } + ExprKind::Assign(lhs, rhs) + | ExprKind::AssignOp(_, lhs, rhs) + | ExprKind::BinOp(_, lhs, rhs) + | ExprKind::ArrayRepeat(lhs, rhs) + | ExprKind::Index(lhs, rhs) => { + transform_expr( + package, package_id, *lhs, param, concrete, alias_set, assigner, + ); + transform_expr( + package, package_id, *rhs, param, concrete, alias_set, assigner, + ); + } + ExprKind::AssignField(a, _, b) | ExprKind::UpdateField(a, _, b) => { + transform_expr( + package, package_id, *a, param, concrete, alias_set, assigner, + ); + transform_expr( + package, package_id, *b, param, concrete, alias_set, assigner, + ); + } + ExprKind::AssignIndex(a, b, c) | ExprKind::UpdateIndex(a, b, c) => { + transform_expr( + package, package_id, *a, param, concrete, alias_set, assigner, + ); + transform_expr( + package, package_id, *b, param, concrete, alias_set, assigner, + ); + transform_expr( + package, package_id, *c, param, concrete, alias_set, assigner, + ); + } + ExprKind::UnOp(_, inner) | ExprKind::Return(inner) | 
ExprKind::Fail(inner) => {
+            transform_expr(
+                package, package_id, *inner, param, concrete, alias_set, assigner,
+            );
+        }
+        ExprKind::Field(inner_id, _) => {
+            // For nested callable params, check if this Field expression
+            // accesses the arrow element within the param variable.
+            if !param.field_path.is_empty()
+                && expr_matches_param_field_path(
+                    package,
+                    expr_id,
+                    param.param_var,
+                    &param.field_path,
+                )
+            {
+                replace_callee(
+                    package,
+                    package_id,
+                    expr_id,
+                    FunctorApp::default(),
+                    concrete,
+                    assigner,
+                );
+                return;
+            }
+            transform_expr(
+                package, package_id, *inner_id, param, concrete, alias_set, assigner,
+            );
+        }
+        ExprKind::Range(a, b, c) => {
+            if let Some(a) = a {
+                transform_expr(
+                    package, package_id, *a, param, concrete, alias_set, assigner,
+                );
+            }
+            if let Some(b) = b {
+                transform_expr(
+                    package, package_id, *b, param, concrete, alias_set, assigner,
+                );
+            }
+            if let Some(c) = c {
+                transform_expr(
+                    package, package_id, *c, param, concrete, alias_set, assigner,
+                );
+            }
+        }
+        ExprKind::String(components) => {
+            for comp in components {
+                if let qsc_fir::fir::StringComponent::Expr(e) = comp {
+                    transform_expr(
+                        package, package_id, *e, param, concrete, alias_set, assigner,
+                    );
+                }
+            }
+        }
+        ExprKind::Struct(_, copy, fields) => {
+            if let Some(c) = copy {
+                transform_expr(
+                    package, package_id, *c, param, concrete, alias_set, assigner,
+                );
+            }
+            for f in fields {
+                transform_expr(
+                    package, package_id, f.value, param, concrete, alias_set, assigner,
+                );
+            }
+        }
+        // Substitute the callable parameter variable (or an alias from
+        // destructuring) at non-callee positions (e.g., when forwarded as an
+        // argument to an inner HOF).
+        ExprKind::Var(Res::Local(var), _)
+            if (*var == param.param_var && param.field_path.is_empty())
+                || alias_set.contains(var) =>
+        {
+            replace_callee(
+                package,
+                package_id,
+                expr_id,
+                FunctorApp::default(),
+                concrete,
+                assigner,
+            );
+        }
+        // When a closure captures the callable parameter being specialized,
+        // propagate the specialization into the closure's target callable and
+        // remove the capture.
+        ExprKind::Closure(captures, target) => {
+            if let Some(capture_idx) = captures
+                .iter()
+                .position(|&c| c == param.param_var || alias_set.contains(&c))
+            {
+                let target = *target;
+                transform_closure_param_capture(
+                    package,
+                    package_id,
+                    expr_id,
+                    target,
+                    capture_idx,
+                    param,
+                    concrete,
+                    assigner,
+                );
+            }
+        }
+        // Terminals with no sub-expressions.
+        ExprKind::Hole | ExprKind::Lit(_) | ExprKind::Var(_, _) => {}
+    }
+}
+
+/// Returns true when an expression is a field chain rooted at `param_var`
+/// and its collected field path exactly matches `field_path`.
+fn expr_matches_param_field_path(
+    package: &Package,
+    expr_id: ExprId,
+    param_var: LocalVarId,
+    field_path: &[usize],
+) -> bool {
+    collect_field_path_from_param(package, expr_id, param_var)
+        .is_some_and(|path| path == field_path)
+}
+
+/// Collects field indices from nested `Field(Path)` expressions rooted at `param_var`.
+fn collect_field_path_from_param(
+    package: &Package,
+    expr_id: ExprId,
+    param_var: LocalVarId,
+) -> Option<Vec<usize>> {
+    let expr = package.get_expr(expr_id);
+    match &expr.kind {
+        ExprKind::Var(Res::Local(var), _) if *var == param_var => Some(Vec::new()),
+        ExprKind::Field(inner_id, Field::Path(FieldPath { indices })) => {
+            let mut path = collect_field_path_from_param(package, *inner_id, param_var)?;
+            path.extend(indices.iter().copied());
+            Some(path)
+        }
+        _ => None,
+    }
+}
+
+/// Replaces the callee expression with a direct reference to the concrete
+/// callable, applying the effective functor (composition of creation-site
+/// and body-site functors).
+/// +/// # Before +/// ```text +/// callee_expr = Var(Local(param_var)) : Arrow // indirect via callable parameter +/// ``` +/// # After +/// ```text +/// callee_expr = Ctl?(Adj?(Var(Item(concrete)))) : Arrow // direct, with functors +/// ``` +/// +/// # Mutations +/// - Overwrites `callee_expr_id`'s `ExprKind` and `Ty` in place. +/// - Allocates functor-wrapper `Expr` nodes through `assigner` when the +/// effective functor is non-trivial. +fn replace_callee( + package: &mut Package, + package_id: PackageId, + callee_expr_id: ExprId, + body_functor: FunctorApp, + concrete: &ConcreteCallable, + assigner: &mut Assigner, +) { + let (target_res, creation_functor) = match concrete { + ConcreteCallable::Global { item_id, functor } => (Res::Item(*item_id), *functor), + ConcreteCallable::Closure { + target, functor, .. + } => { + let item_id = ItemId { + package: package_id, + item: *target, + }; + (Res::Item(item_id), *functor) + } + ConcreteCallable::Dynamic => return, + }; + + let effective = compose_functors(&creation_functor, &body_functor); + + let callee_expr = package.exprs.get(callee_expr_id).expect("expr not found"); + let original_callee_ty = callee_expr.ty.clone(); + let callee_span = callee_expr.span; + let callee_ty = match concrete { + ConcreteCallable::Closure { target, .. } => build_direct_target_callee_ty( + package, + *target, + &original_callee_ty, + usize::from(effective.controlled), + ) + .unwrap_or_else(|| original_callee_ty.clone()), + ConcreteCallable::Global { .. } | ConcreteCallable::Dynamic => original_callee_ty.clone(), + }; + + let base_kind = match concrete { + ConcreteCallable::Closure { + target, captures, .. + } if captures.is_empty() => ExprKind::Closure(Vec::new(), *target), + _ => ExprKind::Var(target_res, Vec::new()), + }; + + if !effective.adjoint && effective.controlled == 0 { + // No functor wrapping — replace directly. 
+ let expr = package + .exprs + .get_mut(callee_expr_id) + .expect("expr not found"); + expr.kind = base_kind; + expr.ty = callee_ty; + } else { + // Allocate fresh expressions for functor wrapper layers. + let mut current_id = assigner.next_expr(); + package.exprs.insert( + current_id, + Expr { + id: current_id, + span: callee_span, + ty: callee_ty.clone(), + kind: base_kind, + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + + if effective.adjoint { + let adj_id = assigner.next_expr(); + package.exprs.insert( + adj_id, + Expr { + id: adj_id, + span: callee_span, + ty: callee_ty.clone(), + kind: ExprKind::UnOp(UnOp::Functor(Functor::Adj), current_id), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + current_id = adj_id; + } + + for _ in 0..effective.controlled { + let ctl_id = assigner.next_expr(); + package.exprs.insert( + ctl_id, + Expr { + id: ctl_id, + span: callee_span, + ty: callee_ty.clone(), + kind: ExprKind::UnOp(UnOp::Functor(Functor::Ctl), current_id), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + current_id = ctl_id; + } + + // Copy the outermost node's kind into the original callee expr. + let outermost_kind = package + .exprs + .get(current_id) + .expect("expr not found") + .kind + .clone(); + let expr = package + .exprs + .get_mut(callee_expr_id) + .expect("expr not found"); + expr.kind = outermost_kind; + expr.ty = callee_ty; + } +} + +/// Derives the arrow type of the direct-call target from the HOF's +/// indirect-call site arrow type, peeling `controlled_layers` to reach the +/// right nested input slot. 
+fn build_direct_target_callee_ty(
+    package: &Package,
+    target_item_id: LocalItemId,
+    callee_ty: &Ty,
+    controlled_layers: usize,
+) -> Option<Ty> {
+    let Ty::Arrow(arrow) = callee_ty else {
+        return None;
+    };
+
+    let ItemKind::Callable(decl) = &package.get_item(target_item_id).kind else {
+        return None;
+    };
+
+    let target_input = package.get_pat(decl.input).ty.clone();
+    let new_input =
+        apply_target_input_at_control_path(&arrow.input, &target_input, controlled_layers);
+
+    Some(Ty::Arrow(Box::new(Arrow {
+        kind: arrow.kind,
+        input: Box::new(new_input),
+        output: arrow.output.clone(),
+        functors: arrow.functors,
+    })))
+}
+
+/// Replaces the innermost input slot beneath `controlled_layers` nested
+/// controlled-operation tuples with `target_input`, returning the rewritten
+/// outer type.
+///
+/// A copy of this helper also lives in
+/// [`super::rewrite::apply_target_input_at_control_path`]; keep the two in
+/// sync when changing controlled-layer handling. See the module-level note
+/// in `rewrite.rs` for why both copies exist.
+fn apply_target_input_at_control_path(
+    current_input: &Ty,
+    target_input: &Ty,
+    controlled_layers: usize,
+) -> Ty {
+    if controlled_layers == 0 {
+        return target_input.clone();
+    }
+
+    match current_input {
+        Ty::Tuple(items) if items.len() > 1 => {
+            let mut new_items = items.clone();
+            new_items[1] = apply_target_input_at_control_path(
+                &new_items[1],
+                target_input,
+                controlled_layers - 1,
+            );
+            Ty::Tuple(new_items)
+        }
+        _ => target_input.clone(),
+    }
+}
+
+/// When the HOF body contains a closure that captures the callable parameter
+/// being specialized, we must propagate the concrete callable into the
+/// closure's target callable and remove the capture so that the `param_var`
+/// reference is eliminated.
+/// +/// # Before +/// ```text +/// Closure([param_var, ...], target) // target body uses param_var +/// ``` +/// # After +/// ```text +/// Closure([...], target') // param_var capture removed; +/// // target body uses concrete callee directly +/// ``` +/// +/// # Mutations +/// - Transforms the closure target's body via [`transform_callable_body`]. +/// - Removes the capture from the target's input pattern via +/// [`remove_capture_from_closure_target`]. +/// - Removes the capture from the `Closure` expression's capture list. +#[allow(clippy::too_many_arguments)] +fn transform_closure_param_capture( + package: &mut Package, + package_id: PackageId, + closure_expr_id: ExprId, + closure_target: LocalItemId, + capture_idx: usize, + param: &CallableParam, + concrete: &ConcreteCallable, + assigner: &mut Assigner, +) { + // Step 1: Find the corresponding binding in the closure target's input pattern. + let target_item = package.items.get(closure_target); + let Some(Item { + kind: ItemKind::Callable(target_decl), + .. + }) = target_item + else { + return; + }; + let target_decl = target_decl.as_ref().clone(); + + let target_input_pat = package + .pats + .get(target_decl.input) + .expect("input pat not found") + .clone(); + + // The input pattern should be a Tuple with captures first, then lambda params. + let capture_param_var = match &target_input_pat.kind { + PatKind::Tuple(pats) => { + if capture_idx >= pats.len() { + return; + } + let capture_pat = package.pats.get(pats[capture_idx]).expect("pat not found"); + match &capture_pat.kind { + PatKind::Bind(ident) => ident.id, + _ => return, + } + } + PatKind::Bind(ident) if capture_idx == 0 => ident.id, + _ => return, + }; + + // Step 2: Create a synthetic CallableParam for the closure target's captured param. 
+ let closure_param = CallableParam::new( + closure_target, + target_decl.input, + capture_idx, + Vec::new(), + capture_param_var, + param.param_ty.clone(), + ); + + // Step 3: Transform the target callable's body to replace uses of the + // captured param with the concrete callable. + transform_callable_body( + package, + package_id, + &target_decl.implementation, + &closure_param, + concrete, + assigner, + ); + + // Step 4: Remove the capture binding from the target callable's input. + remove_capture_from_closure_target(package, closure_target, capture_idx); + + // Step 5: Remove the capture from the Closure expression. + let closure_expr = package + .exprs + .get_mut(closure_expr_id) + .expect("closure expr not found"); + if let ExprKind::Closure(ref mut captures, _) = closure_expr.kind + && capture_idx < captures.len() + { + captures.remove(capture_idx); + } +} + +/// Removes the capture at `capture_idx` from the closure target callable's +/// input pattern tuple. +/// +/// # Before +/// ```text +/// input = (capture_0, capture_1, lambda_param) // capture_idx = 1 +/// ``` +/// # After +/// ```text +/// input = (capture_0, lambda_param) // capture_1 removed +/// ``` +/// +/// # Mutations +/// - Rewrites the input `Pat` node in place (or replaces `decl.input` when +/// flattening a single-element tuple). +fn remove_capture_from_closure_target( + package: &mut Package, + target_item_id: LocalItemId, + capture_idx: usize, +) { + let target_item = package.items.get(target_item_id); + let Some(Item { + kind: ItemKind::Callable(decl), + .. 
+    }) = target_item
+    else {
+        return;
+    };
+    let input_pat_id = decl.input;
+
+    let input_pat = package
+        .pats
+        .get(input_pat_id)
+        .expect("pat not found")
+        .clone();
+    match &input_pat.kind {
+        PatKind::Tuple(pats) => {
+            let new_pats: Vec<PatId> = pats
+                .iter()
+                .enumerate()
+                .filter(|(i, _)| *i != capture_idx)
+                .map(|(_, &p)| p)
+                .collect();
+
+            let tys = match &input_pat.ty {
+                Ty::Tuple(tys) => tys.clone(),
+                _ => vec![input_pat.ty.clone(); pats.len()],
+            };
+            let new_tys: Vec<Ty> = tys
+                .iter()
+                .enumerate()
+                .filter(|(i, _)| *i != capture_idx)
+                .map(|(_, t)| t.clone())
+                .collect();
+
+            if new_pats.len() == 1 {
+                // Flatten single-element tuple.
+                let item = package
+                    .items
+                    .get_mut(target_item_id)
+                    .expect("item not found");
+                if let ItemKind::Callable(ref mut decl) = item.kind {
+                    decl.input = new_pats[0];
+                }
+            } else {
+                let pat_mut = package.pats.get_mut(input_pat_id).expect("pat not found");
+                pat_mut.kind = PatKind::Tuple(new_pats);
+                pat_mut.ty = if new_tys.is_empty() {
+                    Ty::UNIT
+                } else {
+                    Ty::Tuple(new_tys)
+                };
+            }
+        }
+        PatKind::Bind(_) if capture_idx == 0 => {
+            // Only parameter is the capture — replace with unit.
+            let pat_mut = package.pats.get_mut(input_pat_id).expect("pat not found");
+            pat_mut.kind = PatKind::Tuple(Vec::new());
+            pat_mut.ty = Ty::UNIT;
+        }
+        _ => {}
+    }
+}
+
+/// When the concrete callable is a closure, its captured variables must be
+/// threaded as additional parameters to the specialized callable.
+///
+/// # Before
+/// ```text
+/// input = (param_0, param_1) // original HOF input
+/// ```
+/// # After
+/// ```text
+/// input = (param_0, param_1, __capture_0, ..., __capture_N)
+/// ```
+///
+/// # Mutations
+/// - Extends the input `Pat` tuple with new `Bind` patterns for each
+///   capture, or wraps a scalar input in a tuple.
+/// - Allocates new `Pat` and `LocalVarId` nodes through `cloner`.
+fn thread_closure_captures(
+    cloner: &mut FirCloner,
+    package: &mut Package,
+    decl: &mut CallableDecl,
+    _param: &CallableParam,
+    captures: &[CapturedVar],
+) -> Vec<(LocalVarId, Ty)> {
+    if captures.is_empty() {
+        return Vec::new();
+    }
+
+    // Allocate new bindings for each captured variable and build a remap.
+    let mut capture_bindings: Vec<(LocalVarId, Ty)> = Vec::with_capacity(captures.len());
+    let mut new_pat_ids: Vec<PatId> = Vec::new();
+    let mut new_tys: Vec<Ty> = Vec::new();
+
+    for (i, capture) in captures.iter().enumerate() {
+        let new_pat_id = cloner.alloc_pat();
+        let new_local_var = cloner.alloc_local(capture.var);
+        capture_bindings.push((new_local_var, capture.ty.clone()));
+
+        let name: Rc<str> = Rc::from(format!("__capture_{i}"));
+        let new_pat = Pat {
+            id: new_pat_id,
+            span: Span::default(),
+            ty: capture.ty.clone(),
+            kind: PatKind::Bind(Ident {
+                id: new_local_var,
+                span: Span::default(),
+                name,
+            }),
+        };
+        package.pats.insert(new_pat_id, new_pat);
+        new_pat_ids.push(new_pat_id);
+        new_tys.push(capture.ty.clone());
+    }
+
+    // Extend the input with capture patterns.
+    let input_pat = package
+        .pats
+        .get(decl.input)
+        .expect("input pat not found")
+        .clone();
+    match &input_pat.kind {
+        PatKind::Tuple(_) => {
+            let input_pat_mut = package
+                .pats
+                .get_mut(decl.input)
+                .expect("input pat not found");
+            if let PatKind::Tuple(ref mut pats) = input_pat_mut.kind {
+                pats.extend(new_pat_ids);
+            }
+            if let Ty::Tuple(ref mut tys) = input_pat_mut.ty {
+                tys.extend(new_tys);
+            }
+        }
+        PatKind::Bind(_) | PatKind::Discard => {
+            // Wrap in a tuple with the captures.
+ let old_pat_id = decl.input; + let tuple_pat_id = cloner.alloc_pat(); + let mut sub_pats = vec![old_pat_id]; + sub_pats.extend(new_pat_ids); + + let mut all_tys = vec![input_pat.ty.clone()]; + all_tys.extend(new_tys); + + let tuple_pat = Pat { + id: tuple_pat_id, + span: Span::default(), + ty: Ty::Tuple(all_tys), + kind: PatKind::Tuple(sub_pats), + }; + package.pats.insert(tuple_pat_id, tuple_pat); + decl.input = tuple_pat_id; + } + } + + capture_bindings +} + +/// Rewrites the call-argument expression for a closure target by splicing +/// the captured bindings into the appropriate slot of the call's argument +/// tuple. +/// +/// # Before +/// ```text +/// Call(Var(closure_target), original_args) +/// ``` +/// # After +/// ```text +/// Call(Var(closure_target), (__capture_0, ..., original_args)) +/// ``` +/// +/// The original args expression is preserved as a single element in the +/// new outer tuple, not flattened. +/// +/// # Mutations +/// - Delegates to [`rewrite_closure_target_call_args_in_block`] across +/// all specialization bodies. 
+fn rewrite_closure_target_call_args( + package: &mut Package, + callable_impl: &CallableImpl, + package_id: PackageId, + closure_target: LocalItemId, + capture_bindings: &[(LocalVarId, Ty)], + assigner: &mut Assigner, +) { + match callable_impl { + CallableImpl::Intrinsic => {} + CallableImpl::Spec(spec_impl) => { + rewrite_closure_target_call_args_in_block( + package, + spec_impl.body.block, + package_id, + closure_target, + capture_bindings, + assigner, + ); + if let Some(adj) = &spec_impl.adj { + rewrite_closure_target_call_args_in_block( + package, + adj.block, + package_id, + closure_target, + capture_bindings, + assigner, + ); + } + if let Some(ctl) = &spec_impl.ctl { + rewrite_closure_target_call_args_in_block( + package, + ctl.block, + package_id, + closure_target, + capture_bindings, + assigner, + ); + } + if let Some(ctl_adj) = &spec_impl.ctl_adj { + rewrite_closure_target_call_args_in_block( + package, + ctl_adj.block, + package_id, + closure_target, + capture_bindings, + assigner, + ); + } + } + CallableImpl::SimulatableIntrinsic(spec_decl) => { + rewrite_closure_target_call_args_in_block( + package, + spec_decl.block, + package_id, + closure_target, + capture_bindings, + assigner, + ); + } + } +} + +/// Walks a block after closure specialization and prepends captured locals to +/// every call that now targets the closure body directly. +/// +/// Before, calls to `closure_target` still rely on the closure value to carry +/// its captures implicitly. After, each matching call in `block_id` passes the +/// captured locals explicitly so the rewritten target signature is satisfied. 
+fn rewrite_closure_target_call_args_in_block( + package: &mut Package, + block_id: qsc_fir::fir::BlockId, + package_id: PackageId, + closure_target: LocalItemId, + capture_bindings: &[(LocalVarId, Ty)], + assigner: &mut Assigner, +) { + let block = package.get_block(block_id).clone(); + for stmt_id in block.stmts { + rewrite_closure_target_call_args_in_stmt( + package, + stmt_id, + package_id, + closure_target, + capture_bindings, + assigner, + ); + } +} + +/// Applies closure-capture threading to every expression nested under one +/// statement. +/// +/// Before, `stmt_id` may still contain calls whose argument tuple omits the +/// captures now required by `closure_target`. After, all expressions reachable +/// from the statement have been rewritten so those calls pass the captures +/// explicitly. +fn rewrite_closure_target_call_args_in_stmt( + package: &mut Package, + stmt_id: qsc_fir::fir::StmtId, + package_id: PackageId, + closure_target: LocalItemId, + capture_bindings: &[(LocalVarId, Ty)], + assigner: &mut Assigner, +) { + let stmt = package.get_stmt(stmt_id).clone(); + match stmt.kind { + qsc_fir::fir::StmtKind::Expr(expr_id) + | qsc_fir::fir::StmtKind::Semi(expr_id) + | qsc_fir::fir::StmtKind::Local(_, _, expr_id) => rewrite_closure_target_call_args_in_expr( + package, + expr_id, + package_id, + closure_target, + capture_bindings, + assigner, + ), + qsc_fir::fir::StmtKind::Item(_) => {} + } +} + +/// Rewrites an expression subtree so direct calls to a closure target receive +/// explicit capture operands. +/// +/// Before, the expression tree may still contain `Call`s whose callee resolves +/// to `closure_target` but whose args tuple omits the captures that were baked +/// into the original closure value. After, every such call prepends those +/// captures, matching the rewritten direct callable signature. 
+#[allow(clippy::too_many_lines)] +#[allow(clippy::too_many_arguments)] +fn rewrite_closure_target_call_args_in_expr( + package: &mut Package, + expr_id: ExprId, + package_id: PackageId, + closure_target: LocalItemId, + capture_bindings: &[(LocalVarId, Ty)], + assigner: &mut Assigner, +) { + let expr = package.get_expr(expr_id).clone(); + match expr.kind { + ExprKind::Call(callee_id, args_id) => { + rewrite_closure_target_call_args_in_expr( + package, + callee_id, + package_id, + closure_target, + capture_bindings, + assigner, + ); + rewrite_closure_target_call_args_in_expr( + package, + args_id, + package_id, + closure_target, + capture_bindings, + assigner, + ); + + let (base_id, outer_functor) = peel_body_functors(package, callee_id); + let base_expr = package.get_expr(base_id); + if matches!( + base_expr.kind, + ExprKind::Var( + Res::Item(ItemId { + package: callee_package, + item: callee_item, + }), + _ + ) if callee_package == package_id && callee_item == closure_target + ) { + prepend_capture_args_to_call( + package, + args_id, + capture_bindings, + usize::from(outer_functor.controlled), + assigner, + ); + } + } + ExprKind::Block(block_id) => rewrite_closure_target_call_args_in_block( + package, + block_id, + package_id, + closure_target, + capture_bindings, + assigner, + ), + ExprKind::If(cond, body, otherwise) => { + rewrite_closure_target_call_args_in_expr( + package, + cond, + package_id, + closure_target, + capture_bindings, + assigner, + ); + rewrite_closure_target_call_args_in_expr( + package, + body, + package_id, + closure_target, + capture_bindings, + assigner, + ); + if let Some(otherwise) = otherwise { + rewrite_closure_target_call_args_in_expr( + package, + otherwise, + package_id, + closure_target, + capture_bindings, + assigner, + ); + } + } + ExprKind::While(cond, block_id) => { + rewrite_closure_target_call_args_in_expr( + package, + cond, + package_id, + closure_target, + capture_bindings, + assigner, + ); + 
rewrite_closure_target_call_args_in_block( + package, + block_id, + package_id, + closure_target, + capture_bindings, + assigner, + ); + } + ExprKind::Tuple(exprs) | ExprKind::Array(exprs) | ExprKind::ArrayLit(exprs) => { + for expr_id in exprs { + rewrite_closure_target_call_args_in_expr( + package, + expr_id, + package_id, + closure_target, + capture_bindings, + assigner, + ); + } + } + ExprKind::Assign(lhs, rhs) + | ExprKind::AssignOp(_, lhs, rhs) + | ExprKind::BinOp(_, lhs, rhs) + | ExprKind::ArrayRepeat(lhs, rhs) + | ExprKind::Index(lhs, rhs) + | ExprKind::AssignField(lhs, _, rhs) + | ExprKind::UpdateField(lhs, _, rhs) => { + rewrite_closure_target_call_args_in_expr( + package, + lhs, + package_id, + closure_target, + capture_bindings, + assigner, + ); + rewrite_closure_target_call_args_in_expr( + package, + rhs, + package_id, + closure_target, + capture_bindings, + assigner, + ); + } + ExprKind::AssignIndex(a, b, c) | ExprKind::UpdateIndex(a, b, c) => { + rewrite_closure_target_call_args_in_expr( + package, + a, + package_id, + closure_target, + capture_bindings, + assigner, + ); + rewrite_closure_target_call_args_in_expr( + package, + b, + package_id, + closure_target, + capture_bindings, + assigner, + ); + rewrite_closure_target_call_args_in_expr( + package, + c, + package_id, + closure_target, + capture_bindings, + assigner, + ); + } + ExprKind::UnOp(_, inner) + | ExprKind::Return(inner) + | ExprKind::Fail(inner) + | ExprKind::Field(inner, _) => rewrite_closure_target_call_args_in_expr( + package, + inner, + package_id, + closure_target, + capture_bindings, + assigner, + ), + ExprKind::Range(start, step, end) => { + if let Some(start) = start { + rewrite_closure_target_call_args_in_expr( + package, + start, + package_id, + closure_target, + capture_bindings, + assigner, + ); + } + if let Some(step) = step { + rewrite_closure_target_call_args_in_expr( + package, + step, + package_id, + closure_target, + capture_bindings, + assigner, + ); + } + if let 
Some(end) = end { + rewrite_closure_target_call_args_in_expr( + package, + end, + package_id, + closure_target, + capture_bindings, + assigner, + ); + } + } + ExprKind::String(components) => { + for component in components { + if let qsc_fir::fir::StringComponent::Expr(expr_id) = component { + rewrite_closure_target_call_args_in_expr( + package, + expr_id, + package_id, + closure_target, + capture_bindings, + assigner, + ); + } + } + } + ExprKind::Struct(_, copy, fields) => { + if let Some(copy) = copy { + rewrite_closure_target_call_args_in_expr( + package, + copy, + package_id, + closure_target, + capture_bindings, + assigner, + ); + } + for field in fields { + rewrite_closure_target_call_args_in_expr( + package, + field.value, + package_id, + closure_target, + capture_bindings, + assigner, + ); + } + } + ExprKind::Closure(_, _) | ExprKind::Hole | ExprKind::Lit(_) | ExprKind::Var(_, _) => {} + } +} + +/// Prepends captured variables as additional arguments ahead of the +/// existing call-site argument tuple (respecting controlled-layer nesting). +/// +/// # Before +/// ```text +/// args = (original_args) // or (ctrl_qubits, (original_args)) +/// ``` +/// # After +/// ```text +/// args = (__capture_0, ..., __capture_N, original_args) +/// ``` +/// +/// # Mutations +/// - Rewrites `args_id`'s `ExprKind` and `Ty` in place to a `Tuple` +/// containing capture `Var` expressions followed by the preserved args. +/// - Allocates capture `Var` `Expr` nodes through `assigner`. 
+fn prepend_capture_args_to_call( + package: &mut Package, + args_id: ExprId, + capture_bindings: &[(LocalVarId, Ty)], + controlled_layers: usize, + assigner: &mut Assigner, +) { + if controlled_layers > 0 { + let inner_id = match package.get_expr(args_id).kind { + ExprKind::Tuple(ref elements) if elements.len() > 1 => elements[1], + _ => return, + }; + prepend_capture_args_to_call( + package, + inner_id, + capture_bindings, + controlled_layers - 1, + assigner, + ); + let inner_ty = package.get_expr(inner_id).ty.clone(); + let args_expr = package.exprs.get_mut(args_id).expect("args expr not found"); + if let Ty::Tuple(ref mut tys) = args_expr.ty + && tys.len() > 1 + { + tys[1] = inner_ty; + } + return; + } + + let original_args = package.get_expr(args_id).clone(); + let preserved_args_id = assigner.next_expr(); + package.exprs.insert( + preserved_args_id, + Expr { + id: preserved_args_id, + span: original_args.span, + ty: original_args.ty.clone(), + kind: original_args.kind, + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + + let mut tuple_items = Vec::with_capacity(capture_bindings.len() + 1); + let mut tuple_tys = Vec::with_capacity(capture_bindings.len() + 1); + for (capture_var, capture_ty) in capture_bindings { + let capture_expr_id = assigner.next_expr(); + package.exprs.insert( + capture_expr_id, + Expr { + id: capture_expr_id, + span: original_args.span, + ty: capture_ty.clone(), + kind: ExprKind::Var(Res::Local(*capture_var), Vec::new()), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + tuple_items.push(capture_expr_id); + tuple_tys.push(capture_ty.clone()); + } + tuple_items.push(preserved_args_id); + tuple_tys.push(original_args.ty); + + let args_expr = package.exprs.get_mut(args_id).expect("args expr not found"); + args_expr.kind = ExprKind::Tuple(tuple_items); + args_expr.ty = Ty::Tuple(tuple_tys); +} + +/// Removes the callable parameter from the specialized callable's input +/// pattern and updates the corresponding types. 
+/// +/// # Before +/// ```text +/// input = (param_0, callable_param, param_2) // top_level_param = 1 +/// ``` +/// # After +/// ```text +/// input = (param_0, param_2) // callable_param removed +/// ``` +/// +/// # Mutations +/// - Rewrites the input `Pat` node's `kind` and `ty` in place. +/// - Flattens single-element tuples. +/// - For nested params, delegates to [`remove_nested_callable_param`]. +fn remove_callable_param(package: &mut Package, decl: &mut CallableDecl, param: &CallableParam) { + if !param.field_path.is_empty() { + remove_nested_callable_param(package, decl, param); + return; + } + + let input_pat = package + .pats + .get(decl.input) + .expect("input pat not found") + .clone(); + + match &input_pat.kind { + PatKind::Tuple(pats) => { + let mut new_pats: Vec = Vec::new(); + let mut new_tys: Vec = Vec::new(); + + let tys = match &input_pat.ty { + Ty::Tuple(tys) => tys.clone(), + _ => vec![input_pat.ty.clone(); pats.len()], + }; + + for (i, (&pat_id, ty)) in pats.iter().zip(tys.iter()).enumerate() { + if i != param.top_level_param { + new_pats.push(pat_id); + new_tys.push(ty.clone()); + } + } + + if new_pats.len() == 1 { + // Flatten single-element tuple to the single pattern. + decl.input = new_pats[0]; + } else { + let input_pat_mut = package.pats.get_mut(decl.input).expect("pat not found"); + input_pat_mut.kind = PatKind::Tuple(new_pats); + input_pat_mut.ty = Ty::Tuple(new_tys); + } + } + PatKind::Bind(_) => { + // The only parameter IS the callable param — replace with unit. + let input_pat_mut = package.pats.get_mut(decl.input).expect("pat not found"); + input_pat_mut.kind = PatKind::Tuple(Vec::new()); + input_pat_mut.ty = Ty::UNIT; + } + PatKind::Discard => {} + } +} + +/// Removes a nested callable parameter from the specialized callable's input +/// pattern by navigating into the tuple type at the outer position and removing +/// the arrow element at the inner position. 
Also rewrites any destructuring +/// patterns in the body that bind the removed element. +/// +/// # Before +/// ```text +/// input = (outer: (a, callable, b)) // field_path = [1] +/// ``` +/// # After +/// ```text +/// input = (outer: (a, b)) // nested callable removed +/// ``` +/// +/// # Mutations +/// - Rewrites `Pat.ty` for the sub-pattern and outer tuple in place. +/// - Rewrites destructuring patterns in the body via +/// [`rewrite_destructuring_pat_in_block`]. +fn remove_nested_callable_param( + package: &mut Package, + decl: &mut CallableDecl, + param: &CallableParam, +) { + let input_pat = package + .pats + .get(decl.input) + .expect("input pat not found") + .clone(); + + let outer_idx = param.top_level_param; + let inner_path = param.field_path.as_slice(); + + match &input_pat.kind { + PatKind::Tuple(pats) => { + // Navigate to the sub-pattern at outer_idx and modify its type. + let sub_pat_id = pats[outer_idx]; + let sub_pat = package.pats.get(sub_pat_id).expect("pat not found").clone(); + let new_ty = remove_ty_at_nested_path(package, &sub_pat.ty, inner_path); + let sub_pat_mut = package.pats.get_mut(sub_pat_id).expect("pat not found"); + sub_pat_mut.ty = new_ty.clone(); + + // Update the outer tuple's type to reflect the changed sub-parameter. + let input_pat_mut = package.pats.get_mut(decl.input).expect("pat not found"); + if let Ty::Tuple(ref mut tys) = input_pat_mut.ty { + tys[outer_idx] = new_ty; + } + } + PatKind::Bind(_) => { + // Single param that is a tuple type — modify the type directly. + let new_ty = remove_ty_at_nested_path(package, &input_pat.ty, inner_path); + let input_pat_mut = package.pats.get_mut(decl.input).expect("pat not found"); + input_pat_mut.ty = new_ty; + } + PatKind::Discard => {} + } + + // Rewrite destructuring patterns in the body that bind param_var's tuple. 
+ if !inner_path.is_empty() { + if let CallableImpl::Spec(spec_impl) = &decl.implementation { + rewrite_destructuring_pat_in_block( + package, + spec_impl.body.block, + param.param_var, + inner_path, + ); + if let Some(ref adj) = spec_impl.adj { + rewrite_destructuring_pat_in_block(package, adj.block, param.param_var, inner_path); + } + if let Some(ref ctl) = spec_impl.ctl { + rewrite_destructuring_pat_in_block(package, ctl.block, param.param_var, inner_path); + } + if let Some(ref ctl_adj) = spec_impl.ctl_adj { + rewrite_destructuring_pat_in_block( + package, + ctl_adj.block, + param.param_var, + inner_path, + ); + } + } else if let CallableImpl::SimulatableIntrinsic(spec_decl) = &decl.implementation { + rewrite_destructuring_pat_in_block( + package, + spec_decl.block, + param.param_var, + inner_path, + ); + } + } +} + +/// Walks a block and rewrites any destructuring `let` statement whose init +/// expression is `Var(Local(param_var))` by removing the sub-pattern at +/// `inner_path` from the tuple pattern. +/// +/// # Before +/// ```text +/// let (a, callable, b) = param_var; // inner_path = [1] +/// ``` +/// # After +/// ```text +/// let (a, b) = param_var; // callable sub-pattern removed +/// ``` +/// +/// # Mutations +/// - Rewrites `Pat.kind` and `Pat.ty` via [`remove_pat_at_field_path`]. +/// - Updates the init expression's type to match the rewritten pattern. 
+fn rewrite_destructuring_pat_in_block( + package: &mut Package, + block_id: qsc_fir::fir::BlockId, + param_var: LocalVarId, + inner_path: &[usize], +) { + let block = package + .blocks + .get(block_id) + .expect("block not found") + .clone(); + for &stmt_id in &block.stmts { + let stmt = package.stmts.get(stmt_id).expect("stmt not found").clone(); + if let qsc_fir::fir::StmtKind::Local(_, pat_id, expr_id) = &stmt.kind { + let rewrites_param_var = { + let init_expr = package.exprs.get(*expr_id).expect("expr not found"); + matches!(&init_expr.kind, ExprKind::Var(Res::Local(var), _) if *var == param_var) + }; + if rewrites_param_var && remove_pat_at_field_path(package, *pat_id, inner_path) { + let new_init_ty = package.pats.get(*pat_id).expect("pat not found").ty.clone(); + let init_mut = package.exprs.get_mut(*expr_id).expect("expr not found"); + init_mut.ty = new_init_ty; + } + } + } +} + +/// Removes the sub-pattern at `field_path` from a tuple pattern structure, +/// rewriting the outer pattern type so parameter removal stays type- +/// consistent. +/// +/// # Before +/// ```text +/// Pat::Tuple([p0, p1, p2]) // field_path = [1] +/// ``` +/// # After +/// ```text +/// Pat::Tuple([p0, p2]) // p1 removed, ty updated +/// ``` +/// +/// # Mutations +/// - Rewrites `Pat.kind` and `Pat.ty` in place. +/// - Flattens single-element tuples. 
+fn remove_pat_at_field_path(package: &mut Package, pat_id: PatId, field_path: &[usize]) -> bool { + let Some((index, tail)) = field_path.split_first() else { + return false; + }; + + let pat = package.pats.get(pat_id).expect("pat not found").clone(); + let PatKind::Tuple(sub_pats) = &pat.kind else { + return false; + }; + if *index >= sub_pats.len() { + return false; + } + + if tail.is_empty() { + let remaining_pats: Vec = sub_pats + .iter() + .enumerate() + .filter(|(i, _)| *i != *index) + .map(|(_, &sub_pat_id)| sub_pat_id) + .collect(); + let (new_kind, new_ty) = flattened_tuple_pat(package, &remaining_pats); + let pat_mut = package.pats.get_mut(pat_id).expect("pat not found"); + pat_mut.kind = new_kind; + pat_mut.ty = new_ty; + return true; + } + + let child_pat_id = sub_pats[*index]; + if !remove_pat_at_field_path(package, child_pat_id, tail) { + return false; + } + + let new_ty = Ty::Tuple( + sub_pats + .iter() + .map(|sub_pat_id| package.get_pat(*sub_pat_id).ty.clone()) + .collect(), + ); + let pat_mut = package.pats.get_mut(pat_id).expect("pat not found"); + pat_mut.ty = new_ty; + true +} + +/// Flattens a single-element tuple pattern to its contained pattern (so a +/// one-element tuple never survives pattern removal), returning the +/// resulting `(PatKind, Ty)` for the enclosing pattern slot. +fn flattened_tuple_pat(package: &Package, sub_pats: &[PatId]) -> (PatKind, Ty) { + match sub_pats { + [] => (PatKind::Tuple(Vec::new()), Ty::UNIT), + [only_pat_id] => { + let only_pat = package.get_pat(*only_pat_id); + (only_pat.kind.clone(), only_pat.ty.clone()) + } + _ => ( + PatKind::Tuple(sub_pats.to_vec()), + Ty::Tuple( + sub_pats + .iter() + .map(|pat_id| package.get_pat(*pat_id).ty.clone()) + .collect(), + ), + ), + } +} + +/// Removes the element at `path` from a nested tuple type structure. +/// For single-element paths, removes the element at that index from the tuple. +/// For multi-element paths, navigates into the tuple and recursively removes. 
+fn remove_ty_at_nested_path(package: &Package, ty: &Ty, path: &[usize]) -> Ty { + if path.is_empty() { + return Ty::UNIT; + } + let ty = resolve_udt_ty(package, ty); + if let Ty::Tuple(tys) = ty { + if path.len() == 1 { + let remaining: Vec = tys + .iter() + .enumerate() + .filter(|(i, _)| *i != path[0]) + .map(|(_, t)| t.clone()) + .collect(); + if remaining.is_empty() { + Ty::UNIT + } else if remaining.len() == 1 { + remaining.into_iter().next().expect("single element") + } else { + Ty::Tuple(remaining) + } + } else { + let mut new_tys = tys.clone(); + new_tys[path[0]] = remove_ty_at_nested_path(package, &tys[path[0]], &path[1..]); + Ty::Tuple(new_tys) + } + } else { + Ty::UNIT + } +} + +/// Expands UDT wrappers to the tuple/array/arrow structure that defunctionalization tracks. +/// +/// `CallableParam::field_path` is recorded against the pure structural shape of a parameter, +/// but specialization removes the callable parameter before UDT erasure has necessarily run. +/// When the input pattern still has a `Ty::Udt`, `remove_ty_at_nested_path` needs the same +/// structural view that analysis used so a path like `cfg::Inner::Op` can remove the arrow +/// field from the specialized callable's input type. Non-UDT leaves are preserved, and nested +/// tuples, arrays, and arrows are rebuilt with any UDTs inside them expanded as well. 
+fn resolve_udt_ty(package: &Package, ty: &Ty) -> Ty { + match ty { + Ty::Udt(Res::Item(item_id)) => { + let Some(item) = package.items.get(item_id.item) else { + return ty.clone(); + }; + let ItemKind::Ty(_, udt) = &item.kind else { + return ty.clone(); + }; + resolve_udt_ty(package, &udt.get_pure_ty()) + } + Ty::Tuple(elems) => Ty::Tuple( + elems + .iter() + .map(|elem| resolve_udt_ty(package, elem)) + .collect(), + ), + Ty::Array(elem) => Ty::Array(Box::new(resolve_udt_ty(package, elem))), + Ty::Arrow(arrow) => Ty::Arrow(Box::new(qsc_fir::ty::Arrow { + kind: arrow.kind, + input: Box::new(resolve_udt_ty(package, &arrow.input)), + output: Box::new(resolve_udt_ty(package, &arrow.output)), + functors: arrow.functors, + })), + _ => ty.clone(), + } +} + +/// Builds a standalone `Package` holding every node reachable from a +/// callable body so the cloner can read from a disjoint source while the +/// target package is mutated. +fn extract_callable_body(source_pkg: &Package, decl: &CallableDecl) -> Package { + let mut body_pkg = Package::default(); + + extract_pat(source_pkg, decl.input, &mut body_pkg); + + match &decl.implementation { + CallableImpl::Intrinsic => {} + CallableImpl::Spec(spec_impl) => { + extract_spec_decl_body(source_pkg, &spec_impl.body, &mut body_pkg); + for spec in functored_specs(spec_impl) { + extract_spec_decl_body(source_pkg, spec, &mut body_pkg); + } + } + CallableImpl::SimulatableIntrinsic(spec) => { + extract_spec_decl_body(source_pkg, spec, &mut body_pkg); + } + } + + body_pkg +} + +/// Copies a `SpecDecl`'s input pattern and block into the extraction +/// target package. +fn extract_spec_decl_body(source: &Package, spec: &qsc_fir::fir::SpecDecl, target: &mut Package) { + if let Some(pat_id) = spec.input { + extract_pat(source, pat_id, target); + } + extract_block(source, spec.block, target); +} + +/// Recursively copies a block and every statement it references into the +/// extraction target. 
+fn extract_block(source: &Package, block_id: qsc_fir::fir::BlockId, target: &mut Package) { + if target.blocks.contains_key(block_id) { + return; + } + let block = source.get_block(block_id); + target.blocks.insert(block_id, block.clone()); + for &stmt_id in &block.stmts { + extract_stmt(source, stmt_id, target); + } +} + +/// Recursively copies a statement and its referenced patterns, expressions, +/// or items into the extraction target. +fn extract_stmt(source: &Package, stmt_id: qsc_fir::fir::StmtId, target: &mut Package) { + if target.stmts.contains_key(stmt_id) { + return; + } + let stmt = source.get_stmt(stmt_id); + target.stmts.insert(stmt_id, stmt.clone()); + match &stmt.kind { + qsc_fir::fir::StmtKind::Expr(e) | qsc_fir::fir::StmtKind::Semi(e) => { + extract_expr(source, *e, target); + } + qsc_fir::fir::StmtKind::Local(_, pat, expr) => { + extract_pat(source, *pat, target); + extract_expr(source, *expr, target); + } + qsc_fir::fir::StmtKind::Item(item_id) => { + extract_item(source, *item_id, target); + } + } +} + +#[allow(clippy::too_many_lines)] +/// Recursively copies an expression and its transitive references into the +/// extraction target. +/// +/// NOTE: This is intentionally a separate implementation from the nearly +/// identical `extract_expr` in `monomorphize.rs`. The key difference is the +/// `ExprKind::Closure` arm: defunctionalize treats it as a leaf because +/// lambda-lifted items already live at package level and the +/// [`FirCloner`](crate::cloner::FirCloner) resolves them via its fallback +/// path (keeping the original `LocalItemId` in the target package). +/// Defunctionalize does not perform type substitution on cloned bodies, so +/// duplicating the lambda item would be wasteful. 
+/// +/// However, `StmtKind::Item` (named nested functions declared inside the +/// HOF body) MUST be followed here +fn extract_expr(source: &Package, expr_id: ExprId, target: &mut Package) { + if target.exprs.contains_key(expr_id) { + return; + } + let expr = source.get_expr(expr_id); + target.exprs.insert(expr_id, expr.clone()); + match &expr.kind { + ExprKind::Array(es) | ExprKind::ArrayLit(es) | ExprKind::Tuple(es) => { + for &e in es { + extract_expr(source, e, target); + } + } + ExprKind::ArrayRepeat(a, b) + | ExprKind::Assign(a, b) + | ExprKind::AssignOp(_, a, b) + | ExprKind::BinOp(_, a, b) + | ExprKind::Call(a, b) + | ExprKind::Index(a, b) + | ExprKind::AssignField(a, _, b) + | ExprKind::UpdateField(a, _, b) => { + extract_expr(source, *a, target); + extract_expr(source, *b, target); + } + ExprKind::AssignIndex(a, b, c) | ExprKind::UpdateIndex(a, b, c) => { + extract_expr(source, *a, target); + extract_expr(source, *b, target); + extract_expr(source, *c, target); + } + ExprKind::Block(block_id) => extract_block(source, *block_id, target), + ExprKind::Fail(e) | ExprKind::Field(e, _) | ExprKind::Return(e) | ExprKind::UnOp(_, e) => { + extract_expr(source, *e, target); + } + ExprKind::If(cond, body, otherwise) => { + extract_expr(source, *cond, target); + extract_expr(source, *body, target); + if let Some(e) = otherwise { + extract_expr(source, *e, target); + } + } + ExprKind::Range(s, st, e) => { + for x in [s, st, e].into_iter().flatten() { + extract_expr(source, *x, target); + } + } + ExprKind::Struct(_, copy, fields) => { + if let Some(c) = copy { + extract_expr(source, *c, target); + } + for fa in fields { + extract_expr(source, fa.value, target); + } + } + ExprKind::String(components) => { + for c in components { + if let qsc_fir::fir::StringComponent::Expr(e) = c { + extract_expr(source, *e, target); + } + } + } + ExprKind::While(cond, block) => { + extract_expr(source, *cond, target); + extract_block(source, *block, target); + } + 
ExprKind::Closure(_, _) | ExprKind::Hole | ExprKind::Lit(_) | ExprKind::Var(_, _) => {} + } +} + +/// Recursively copies a nested item (named function declared inside a block) +/// and its callable body into the extraction target so that +/// [`FirCloner::clone_nested_item`](crate::cloner::FirCloner) can find it +/// during specialization. +fn extract_item(source: &Package, item_id: LocalItemId, target: &mut Package) { + if target.items.contains_key(item_id) { + return; + } + let item = source.get_item(item_id); + target.items.insert(item_id, item.clone()); + if let ItemKind::Callable(decl) = &item.kind { + extract_pat(source, decl.input, target); + match &decl.implementation { + CallableImpl::Intrinsic => {} + CallableImpl::Spec(spec_impl) => { + extract_spec_decl_body(source, &spec_impl.body, target); + for spec in functored_specs(spec_impl) { + extract_spec_decl_body(source, spec, target); + } + } + CallableImpl::SimulatableIntrinsic(spec) => { + extract_spec_decl_body(source, spec, target); + } + } + } +} + +/// Recursively copies a pattern and its sub-patterns into the extraction +/// target. +fn extract_pat(source: &Package, pat_id: PatId, target: &mut Package) { + if target.pats.contains_key(pat_id) { + return; + } + let pat = source.get_pat(pat_id); + target.pats.insert(pat_id, pat.clone()); + if let PatKind::Tuple(sub_pats) = &pat.kind { + for &p in sub_pats { + extract_pat(source, p, target); + } + } +} diff --git a/source/compiler/qsc_fir_transforms/src/defunctionalize/tests.rs b/source/compiler/qsc_fir_transforms/src/defunctionalize/tests.rs new file mode 100644 index 0000000000..098fad5e20 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/defunctionalize/tests.rs @@ -0,0 +1,685 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Tests for the defunctionalization pass. 
+ +use std::any::Any; + +use expect_test::{Expect, expect}; +use qsc_data_structures::target::TargetCapabilityFlags; +use qsc_fir::fir::{self, ItemId, ItemKind, LocalItemId, PackageLookup, PackageStoreLookup}; + +use super::analysis as defunc_analysis; +use super::defunctionalize; +use super::types::{ + CallableParam, CalleeLattice, ConcreteCallable, ConcreteCallableKey, SpecKey, compose_functors, +}; +use crate::fir_builder::reachable_local_callables; +use crate::reachability::collect_reachable_from_entry; +use crate::test_utils::{ + compile_to_monomorphized_fir, compile_to_monomorphized_fir_with_capabilities, +}; +use crate::walk_utils::collect_expr_ids_in_entry_and_local_callables; +use crate::{invariants as fir_invariants, invariants::InvariantLevel}; +use qsc_data_structures::functors::FunctorApp; + +mod analysis; +mod cross_package; +mod fixpoint; +mod invariants; +mod prepass; +mod specialization; + +fn adaptive_qirgen_capabilities() -> TargetCapabilityFlags { + TargetCapabilityFlags::Adaptive + | TargetCapabilityFlags::IntegerComputations + | TargetCapabilityFlags::FloatingPointComputations +} + +fn format_defunctionalization_errors(errors: &[super::Error]) -> String { + if errors.is_empty() { + "(no error)".to_string() + } else { + errors + .iter() + .map(ToString::to_string) + .collect::>() + .join("\n") + } +} + +fn assert_no_defunctionalization_errors(context: &str, errors: &[super::Error]) { + assert!( + errors.is_empty(), + "{context} produced errors:\n{}", + format_defunctionalization_errors(errors) + ); +} + +fn panic_message(panic: Box) -> String { + match panic.downcast::() { + Ok(message) => *message, + Err(panic) => match panic.downcast::<&str>() { + Ok(message) => (*message).to_string(), + Err(_) => "(non-string panic payload)".to_string(), + }, + } +} + +/// Compiles Q# source, runs defunctionalization, and snapshots the reachable +/// callable names and their input pattern types from the user package. 
+fn check(source: &str, expect: &Expect) { + let (fir_store, fir_pkg_id) = compile_and_defunctionalize(source); + let package = fir_store.get(fir_pkg_id); + let reachable = collect_reachable_from_entry(&fir_store, fir_pkg_id); + + let mut lines: Vec = Vec::new(); + for store_id in &reachable { + if store_id.package != fir_pkg_id { + continue; + } + let item = package.get_item(store_id.item); + if let ItemKind::Callable(decl) = &item.kind { + let pat = package.get_pat(decl.input); + lines.push(format!("{}: input_ty={}", decl.name.name, pat.ty)); + } + } + lines.sort(); + expect.assert_eq(&lines.join("\n")); +} + +fn compile_and_defunctionalize(source: &str) -> (fir::PackageStore, fir::PackageId) { + let (mut fir_store, fir_pkg_id) = compile_to_monomorphized_fir(source); + let mut assigner = qsc_fir::assigner::Assigner::from_package(fir_store.get(fir_pkg_id)); + let errors = defunctionalize(&mut fir_store, fir_pkg_id, &mut assigner); + assert_no_defunctionalization_errors("defunctionalization", &errors); + (fir_store, fir_pkg_id) +} + +fn callable_decl<'a>(package: &'a fir::Package, callable_name: &str) -> &'a fir::CallableDecl { + package + .items + .values() + .find_map(|item| match &item.kind { + ItemKind::Callable(decl) if decl.name.name.as_ref() == callable_name => { + Some(decl.as_ref()) + } + _ => None, + }) + .unwrap_or_else(|| panic!("callable '{callable_name}' not found")) +} + +fn call_arg_tuple_lengths_after_defunc(source: &str, callee_name: &str) -> Vec { + let (fir_store, fir_pkg_id) = compile_and_defunctionalize(source); + let package = fir_store.get(fir_pkg_id); + let mut lengths = Vec::new(); + for expr in package.exprs.values() { + let fir::ExprKind::Call(callee_id, args_id) = &expr.kind else { + continue; + }; + let callee = package.get_expr(*callee_id); + let fir::ExprKind::Var(fir::Res::Item(item_id), _) = &callee.kind else { + continue; + }; + if resolve_item_name(&fir_store, item_id) != callee_name { + continue; + } + let args = 
package.get_expr(*args_id); + let len = match &args.kind { + fir::ExprKind::Tuple(elements) => elements.len(), + _ => 1, + }; + lengths.push(len); + } + lengths.sort_unstable(); + lengths +} + +fn callable_call_targets_after_defunc(source: &str, callable_name: &str) -> Vec { + let (fir_store, fir_pkg_id) = compile_and_defunctionalize(source); + let package = fir_store.get(fir_pkg_id); + let decl = callable_decl(package, callable_name); + let mut targets = Vec::new(); + crate::walk_utils::for_each_expr_in_callable_impl( + package, + &decl.implementation, + &mut |_expr_id, expr| { + if let fir::ExprKind::Call(callee_id, _) = &expr.kind + && let Some(target) = call_target_name(&fir_store, package, *callee_id) + { + targets.push(target); + } + }, + ); + targets.sort(); + targets +} + +fn call_target_name( + store: &fir::PackageStore, + package: &fir::Package, + expr_id: fir::ExprId, +) -> Option { + let expr = package.get_expr(expr_id); + match &expr.kind { + fir::ExprKind::Var(fir::Res::Item(item_id), _) => Some(resolve_item_name(store, item_id)), + fir::ExprKind::UnOp(fir::UnOp::Functor(fir::Functor::Adj), inner) => { + call_target_name(store, package, *inner).map(|name| format!("Adjoint {name}")) + } + fir::ExprKind::UnOp(fir::UnOp::Functor(fir::Functor::Ctl), inner) => { + call_target_name(store, package, *inner).map(|name| format!("Controlled {name}")) + } + _ => None, + } +} + +/// Resolves an `ItemId` to its callable name, falling back to the raw display. +fn resolve_item_name(store: &fir::PackageStore, id: &ItemId) -> String { + let store_id = fir::StoreItemId { + package: id.package, + item: id.item, + }; + let item = store.get_item(store_id); + if let ItemKind::Callable(decl) = &item.kind { + decl.name.name.to_string() + } else { + format!("{id}") + } +} + +/// Formats a `FunctorApp` as a short specialization label. 
+fn functor_app_short(f: FunctorApp) -> &'static str { + match (f.adjoint, f.controlled) { + (false, 0) => "Body", + (true, 0) => "Adj", + (false, _) => "Ctl", + (true, _) => "CtlAdj", + } +} + +/// Formats a `ConcreteCallable` for snapshot display. +fn format_concrete_callable(cc: &ConcreteCallable, store: &fir::PackageStore) -> String { + match cc { + ConcreteCallable::Global { item_id, functor } => { + let name = resolve_item_name(store, item_id); + let spec = functor_app_short(*functor); + format!("{name}:{spec}") + } + ConcreteCallable::Closure { + target, functor, .. + } => { + let spec = functor_app_short(*functor); + format!("Closure({target}):{spec}") + } + ConcreteCallable::Dynamic => "Dynamic".to_string(), + } +} + +fn callable_param_display_path(param: &CallableParam) -> Vec<usize> { + std::iter::once(param.top_level_param) + .chain(param.field_path.iter().copied()) + .collect() +} + +/// Compiles Q# source, runs the defunctionalization pre-pass and analysis, and +/// snapshots the analysis results. 
+fn check_analysis(source: &str, expect: &Expect) { + check_analysis_with_capabilities(source, TargetCapabilityFlags::empty(), expect); +} + +fn check_analysis_with_capabilities( + source: &str, + capabilities: TargetCapabilityFlags, + expect: &Expect, +) { + let (mut fir_store, fir_pkg_id) = + compile_to_monomorphized_fir_with_capabilities(source, capabilities); + let reachable = collect_reachable_from_entry(&fir_store, fir_pkg_id); + let package = fir_store.get(fir_pkg_id); + let local_item_ids: Vec<_> = reachable_local_callables(package, fir_pkg_id, &reachable) + .map(|(id, _)| id) + .collect(); + let reachable_expr_ids = + collect_expr_ids_in_entry_and_local_callables(package, &local_item_ids); + super::prepass::run(&mut fir_store, fir_pkg_id, &reachable_expr_ids); + let result = defunc_analysis::analyze(&mut fir_store, fir_pkg_id, &reachable); + + let mut lines: Vec<String> = Vec::new(); + lines.push(format!("callable_params: {}", result.callable_params.len())); + for param in &result.callable_params { + lines.push(format!( + " param: callable_id={}, path={:?}, ty={}", + param.callable_id, + callable_param_display_path(param), + param.param_ty + )); + } + lines.push(format!("call_sites: {}", result.call_sites.len())); + for cs in &result.call_sites { + let hof_name = resolve_item_name(&fir_store, &cs.hof_item_id); + let arg_desc = match &cs.callable_arg { + ConcreteCallable::Global { item_id, functor } => { + let name = resolve_item_name(&fir_store, item_id); + let spec = functor_app_short(*functor); + format!("Global({name}, {spec})") + } + ConcreteCallable::Closure { + target, functor, .. 
+ } => { + let spec = functor_app_short(*functor); + format!("Closure(target={target}, {spec})") + } + ConcreteCallable::Dynamic => "Dynamic".to_string(), + }; + lines.push(format!(" site: hof={hof_name}, arg={arg_desc}")); + } + + let mut direct_call_site_lines: Vec<_> = result + .direct_call_sites + .iter() + .map(|site| { + let condition = site.condition.map_or_else( + || "default".to_string(), + |expr| format!("condition={expr:?}"), + ); + format!( + " site: callee={}, {condition}", + format_concrete_callable(&site.callable, &fir_store) + ) + }) + .collect(); + if !direct_call_site_lines.is_empty() { + lines.push(format!( + "direct_call_sites: {}", + direct_call_site_lines.len() + )); + direct_call_site_lines.sort(); + lines.extend(direct_call_site_lines); + } + + let mut lattice_items: Vec<_> = result.lattice_states.iter().collect(); + lattice_items.sort_by_key(|(id, _)| **id); + if !lattice_items.is_empty() { + lines.push("lattice states:".to_string()); + for (item_id, entries) in &lattice_items { + let callable_item_id = ItemId { + package: fir_pkg_id, + item: **item_id, + }; + let name = resolve_item_name(&fir_store, &callable_item_id); + lines.push(format!(" callable {name}:")); + for (var_id, lattice) in *entries { + let desc = match lattice { + CalleeLattice::Bottom => continue, + CalleeLattice::Single(cc) => { + format!("Single({})", format_concrete_callable(cc, &fir_store)) + } + CalleeLattice::Multi(candidates) => { + let names: Vec<String> = candidates + .iter() + .map(|(cc, _)| format_concrete_callable(cc, &fir_store)) + .collect(); + format!("Multi([{}])", names.join(", ")) + } + CalleeLattice::Dynamic => "Dynamic".to_string(), + }; + lines.push(format!(" {var_id}: {desc}")); + } + } + } + + expect.assert_eq(&lines.join("\n")); +} + +/// Compiles Q# source, runs defunctionalization, and asserts `PostDefunc` +/// invariants hold. 
+fn check_invariants(source: &str) { + check_invariants_with_capabilities(source, TargetCapabilityFlags::empty()); +} + +fn check_invariants_with_capabilities(source: &str, capabilities: TargetCapabilityFlags) { + let (mut fir_store, fir_pkg_id) = + compile_to_monomorphized_fir_with_capabilities(source, capabilities); + let mut assigner = qsc_fir::assigner::Assigner::from_package(fir_store.get(fir_pkg_id)); + let errors = defunctionalize(&mut fir_store, fir_pkg_id, &mut assigner); + assert_no_defunctionalization_errors("defunctionalization", &errors); + fir_invariants::check(&fir_store, fir_pkg_id, InvariantLevel::PostDefunc); +} + +/// Compiles Q# source, runs defunctionalization, and snapshots the returned +/// error messages for comparison. +fn check_errors(source: &str, expect: &Expect) { + let (mut store, package_id) = compile_to_monomorphized_fir(source); + let mut assigner = qsc_fir::assigner::Assigner::from_package(store.get(package_id)); + let errors = defunctionalize(&mut store, package_id, &mut assigner); + expect.assert_eq(&format_defunctionalization_errors(&errors)); +} + +/// Compiles Q# source and runs the full FIR pipeline including monomorphization, +/// defunctionalization, and subsequent passes. 
+fn check_pipeline(source: &str) { + let (mut fir_store, fir_pkg_id) = crate::test_utils::compile_to_fir(source); + let errors = crate::run_pipeline(&mut fir_store, fir_pkg_id); + crate::test_utils::assert_no_pipeline_errors("run_pipeline", &errors); +} + +#[test] +fn compose_functors_identity() { + let a = FunctorApp::default(); + let b = FunctorApp::default(); + let result = compose_functors(&a, &b); + assert_eq!(result, FunctorApp::default()); +} + +#[test] +fn compose_functors_adj_toggle() { + let a = FunctorApp { + adjoint: true, + controlled: 0, + }; + let b = FunctorApp { + adjoint: true, + controlled: 0, + }; + let result = compose_functors(&a, &b); + assert!(!result.adjoint, "adj XOR adj should cancel"); + assert_eq!(result.controlled, 0); +} + +#[test] +fn compose_functors_ctl_stack() { + let a = FunctorApp { + adjoint: false, + controlled: 1, + }; + let b = FunctorApp { + adjoint: false, + controlled: 1, + }; + let result = compose_functors(&a, &b); + assert!(!result.adjoint); + assert_eq!(result.controlled, 2); +} + +#[test] +fn compose_functors_adj_and_ctl() { + let a = FunctorApp { + adjoint: true, + controlled: 1, + }; + let b = FunctorApp { + adjoint: false, + controlled: 1, + }; + let result = compose_functors(&a, &b); + assert!(result.adjoint, "true XOR false = true"); + assert_eq!(result.controlled, 2); +} + +#[test] +fn spec_key_equality() { + let key1 = SpecKey { + hof_id: LocalItemId::from(5usize), + concrete_args: vec![ConcreteCallableKey::Global { + item_id: ItemId { + package: fir::PackageId::from(1usize), + item: LocalItemId::from(10usize), + }, + functor: FunctorApp::default(), + }], + }; + let key2 = SpecKey { + hof_id: LocalItemId::from(5usize), + concrete_args: vec![ConcreteCallableKey::Global { + item_id: ItemId { + package: fir::PackageId::from(1usize), + item: LocalItemId::from(10usize), + }, + functor: FunctorApp::default(), + }], + }; + assert_eq!(key1, key2); +} + +#[test] +fn spec_key_different() { + let key1 = SpecKey { + 
hof_id: LocalItemId::from(5usize), + concrete_args: vec![ConcreteCallableKey::Global { + item_id: ItemId { + package: fir::PackageId::from(1usize), + item: LocalItemId::from(10usize), + }, + functor: FunctorApp::default(), + }], + }; + let key2 = SpecKey { + hof_id: LocalItemId::from(5usize), + concrete_args: vec![ConcreteCallableKey::Global { + item_id: ItemId { + package: fir::PackageId::from(1usize), + item: LocalItemId::from(20usize), + }, + functor: FunctorApp::default(), + }], + }; + assert_ne!(key1, key2); +} + +#[test] +fn spec_key_hash_consistent() { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + let key1 = SpecKey { + hof_id: LocalItemId::from(5usize), + concrete_args: vec![ConcreteCallableKey::Global { + item_id: ItemId { + package: fir::PackageId::from(1usize), + item: LocalItemId::from(10usize), + }, + functor: FunctorApp::default(), + }], + }; + let key2 = key1.clone(); + + let mut hasher1 = DefaultHasher::new(); + key1.hash(&mut hasher1); + let mut hasher2 = DefaultHasher::new(); + key2.hash(&mut hasher2); + assert_eq!(hasher1.finish(), hasher2.finish()); +} + +#[test] +fn concrete_callable_key_global() { + let key = ConcreteCallableKey::Global { + item_id: ItemId { + package: fir::PackageId::from(1usize), + item: LocalItemId::from(42usize), + }, + functor: FunctorApp { + adjoint: true, + controlled: 1, + }, + }; + match &key { + ConcreteCallableKey::Global { item_id, functor } => { + assert_eq!(item_id.item, LocalItemId::from(42usize)); + assert!(functor.adjoint); + assert_eq!(functor.controlled, 1); + } + ConcreteCallableKey::Closure { .. 
} => panic!("expected Global variant"), + } +} + +#[test] +fn concrete_callable_key_closure() { + let key = ConcreteCallableKey::Closure { + target: LocalItemId::from(7usize), + functor: FunctorApp { + adjoint: false, + controlled: 2, + }, + }; + match &key { + ConcreteCallableKey::Closure { target, functor } => { + assert_eq!(*target, LocalItemId::from(7usize)); + assert!(!functor.adjoint); + assert_eq!(functor.controlled, 2); + } + ConcreteCallableKey::Global { .. } => panic!("expected Closure variant"), + } +} + +#[test] +fn error_diagnostic_has_code() { + use miette::Diagnostic; + use qsc_data_structures::span::Span; + + let error = super::Error::DynamicCallable(Span::default()); + let code = error + .code() + .expect("DynamicCallable should have a diagnostic code"); + assert_eq!(code.to_string(), "Qsc.Defunctionalize.DynamicCallable"); +} + +#[test] +fn error_recursive_specialization() { + use miette::Diagnostic; + use qsc_data_structures::span::Span; + + let error = super::Error::RecursiveSpecialization(Span { lo: 42, hi: 50 }); + expect!["specialization leads to infinite recursion"].assert_eq(&error.to_string()); + let code = error + .code() + .expect("RecursiveSpecialization should have a diagnostic code"); + assert_eq!( + code.to_string(), + "Qsc.Defunctionalize.RecursiveSpecialization" + ); +} + +#[test] +fn empty_entrypoint_remains_unchanged() { + check( + "operation Main() : Unit { }", + &expect![[r#" + Main: input_ty=Unit"#]], + ); +} + +#[test] +fn test_helpers_surface_defunctionalization_errors() { + let source = r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { op(q); } + operation Main() : Unit { + use q = Qubit(); + mutable op = H; + mutable n = 3; + while n > 0 { + op = X; + n -= 1; + } + ApplyOp(op, q); + } + "#; + + let check_panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + check(source, &expect![[r#"should not reach snapshot assertion"#]]); + })) + .expect_err("check should panic when defunctionalization 
returns errors"); + let check_message = panic_message(check_panic); + assert!( + check_message.contains("defunctionalization produced errors"), + "unexpected check panic: {check_message}" + ); + assert!( + check_message.contains("callable argument could not be resolved statically"), + "unexpected check panic: {check_message}" + ); + + let pipeline_panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + check_pipeline(source); + })) + .expect_err("check_pipeline should panic when run_pipeline returns defunctionalization errors"); + let pipeline_message = panic_message(pipeline_panic); + assert!( + pipeline_message.contains("produced FIR transform pipeline errors"), + "unexpected check_pipeline panic: {pipeline_message}" + ); + assert!( + pipeline_message.contains("callable argument could not be resolved statically"), + "unexpected check_pipeline panic: {pipeline_message}" + ); +} + +/// A HOF whose body defines a nested lambda (lifted to a +/// `StmtKind::Item` in FIR) must have that item included in the extracted +/// body package so that `FirCloner::clone_nested_item` can find it during +/// specialization. +#[test] +fn hof_with_nested_lambda_in_body_specializes_correctly() { + // `Transform` is a HOF (takes `f : Int -> Int`). Its body defines + // `helper` as a local lambda — the compiler lifts this to a nested + // item and references it via `StmtKind::Item` + `ExprKind::Closure`. + // When `Transform` is specialized for `x -> x + 1`, the body extraction + // must include the nested item or FirCloner will panic. + let source = r#" + function Transform(f : Int -> Int, x : Int) : Int { + let helper = y -> y * 2; + helper(f(x)) + } + function Main() : Int { + Transform(x -> x + 1, 5) + } + "#; + check_pipeline(source); +} + +/// A HOF whose body defines a *named* nested function +/// (which appears as `StmtKind::Item` in FIR) must have that item included +/// in the extracted body package for specialization to succeed. 
+#[test] +fn hof_with_nested_named_function_specializes_correctly() { + let source = r#" + function Transform(f : Int -> Int, x : Int) : Int { + function Helper(y : Int) : Int { y * 2 } + Helper(f(x)) + } + function Main() : Int { + Transform(x -> x + 1, 5) + } + "#; + check_pipeline(source); +} + +#[test] +fn unreachable_closure_structure_preserved() { + // Reachable: Main calls Apply with a closure. + // Dead: DeadFn uses a different closure pattern. + // Document whether the dead closure structure is mutated by defunctionalization. + use indoc::indoc; + let (mut fir_store, fir_pkg_id) = compile_to_monomorphized_fir(indoc! {" + namespace Test { + @EntryPoint() + operation Main() : Int { + Apply(x -> x + 1, 5) + } + function Apply(f : Int -> Int, x : Int) : Int { f(x) } + // Dead — never called from entry + function DeadFn() : Int { + Apply(x -> x * 2, 10) + } + } + "}); + let mut assigner = qsc_fir::assigner::Assigner::from_package(fir_store.get(fir_pkg_id)); + let errors = defunctionalize(&mut fir_store, fir_pkg_id, &mut assigner); + assert_no_defunctionalization_errors("unreachable_closure_structure_preserved", &errors); + + // Document that dead callable still exists (item DCE hasn't run yet). + let package = fir_store.get(fir_pkg_id); + let dead_exists = package.items.values().any(|item| { + matches!(&item.kind, ItemKind::Callable(decl) if decl.name.name.as_ref() == "DeadFn") + }); + assert!(dead_exists, "DeadFn should still exist pre-DCE"); +} diff --git a/source/compiler/qsc_fir_transforms/src/defunctionalize/tests/analysis.rs b/source/compiler/qsc_fir_transforms/src/defunctionalize/tests/analysis.rs new file mode 100644 index 0000000000..6c14fa8722 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/defunctionalize/tests/analysis.rs @@ -0,0 +1,1213 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +use crate::defunctionalize::analysis::{LocalState, resolve_captures}; + +use super::*; +use expect_test::expect; +use qsc_data_structures::index_map::IndexMap; +use qsc_fir::fir::{LocalVarId, Package}; +use rustc_hash::FxHashSet; + +#[test] +fn analysis_no_callable_params() { + check_analysis( + "operation Main() : Unit { }", + &expect![[r#" + callable_params: 0 + call_sites: 0"#]], + ); +} + +#[test] +fn analysis_single_callable_param() { + check_analysis( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(H, q); + } + "#, + &expect![[r#" + callable_params: 1 + param: callable_id=3, path=[0], ty=(Qubit => Unit is Adj + Ctl) + call_sites: 1 + site: hof=ApplyOp, arg=Global(H, Body)"#]], + ); +} + +#[test] +fn analysis_multiple_callable_params() { + check_analysis( + r#" + operation ApplyTwo(f : Qubit => Unit, g : Qubit => Unit, q : Qubit) : Unit { + f(q); + g(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyTwo(H, X, q); + } + "#, + &expect![[r#" + callable_params: 2 + param: callable_id=3, path=[0], ty=(Qubit => Unit is Adj + Ctl) + param: callable_id=3, path=[1], ty=(Qubit => Unit is Adj + Ctl) + call_sites: 2 + site: hof=ApplyTwo, arg=Global(H, Body) + site: hof=ApplyTwo, arg=Global(X, Body)"#]], + ); +} + +#[test] +fn analysis_callable_param_in_tuple() { + check_analysis( + r#" + operation ApplySecond(q : Qubit, op : Qubit => Unit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplySecond(q, H); + } + "#, + &expect![[r#" + callable_params: 1 + param: callable_id=3, path=[1], ty=(Qubit => Unit is Adj + Ctl) + call_sites: 1 + site: hof=ApplySecond, arg=Global(H, Body)"#]], + ); +} + +#[test] +fn analysis_global_callable_arg() { + check_analysis( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(X, q); + } + "#, + &expect![[r#" + callable_params: 1 + param: 
callable_id=4, path=[0], ty=(Qubit => Unit is Adj + Ctl) + call_sites: 1 + site: hof=ApplyOp, arg=Global(X, Body)"#]], + ); +} + +#[test] +fn analysis_closure_callable_arg() { + check_analysis( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(q1 => H(q1), q); + } + "#, + &expect![[r#" + callable_params: 1 + param: callable_id=4, path=[0], ty=(Qubit => Unit) + call_sites: 1 + site: hof=ApplyOp, arg=Global(H, Body)"#]], + ); +} + +#[test] +fn analysis_adjoint_callable_arg() { + check_analysis( + r#" + operation ApplyOp(op : Qubit => Unit is Adj, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(Adjoint S, q); + } + "#, + &expect![[r#" + callable_params: 1 + param: callable_id=3, path=[0], ty=(Qubit => Unit is Adj + Ctl) + call_sites: 1 + site: hof=ApplyOp, arg=Global(S, Adj)"#]], + ); +} + +#[test] +fn analysis_controlled_callable_arg() { + check_analysis( + r#" + operation ApplyOp(op : (Qubit[], Qubit) => Unit is Ctl, q : Qubit) : Unit { + op([], q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(Controlled X, q); + } + "#, + &expect![[r#" + callable_params: 1 + param: callable_id=4, path=[0], ty=(((Qubit)[], Qubit) => Unit is Adj + Ctl) + call_sites: 1 + site: hof=ApplyOp, arg=Global(X, Ctl)"#]], + ); +} + +#[test] +fn analysis_multiple_call_sites_same_hof() { + check_analysis( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(H, q); + ApplyOp(X, q); + } + "#, + &expect![[r#" + callable_params: 1 + param: callable_id=3, path=[0], ty=(Qubit => Unit is Adj + Ctl) + call_sites: 2 + site: hof=ApplyOp, arg=Global(H, Body) + site: hof=ApplyOp, arg=Global(X, Body)"#]], + ); +} + +#[test] +fn analysis_single_assignment_local_traced() { + check_analysis( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() 
: Unit { + use q = Qubit(); + let myH = H; + ApplyOp(myH, q); + } + "#, + &expect![[r#" + callable_params: 1 + param: callable_id=3, path=[0], ty=(Qubit => Unit is Adj + Ctl) + call_sites: 1 + site: hof=ApplyOp, arg=Global(H, Body) + lattice states: + callable Main: + 2: Single(H:Body)"#]], + ); +} + +#[test] +fn analysis_dynamic_callable_detected() { + check_analysis( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + mutable op = H; + op = X; + ApplyOp(op, q); + } + "#, + &expect![[r#" + callable_params: 1 + param: callable_id=3, path=[0], ty=(Qubit => Unit is Adj + Ctl) + call_sites: 1 + site: hof=ApplyOp, arg=Global(X, Body) + lattice states: + callable Main: + 2: Single(X:Body)"#]], + ); +} + +#[test] +fn udt_field_single_level_direct() { + check_analysis( + r#" + struct Config { Apply : Qubit => Unit } + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + let config = new Config { Apply = H }; + ApplyOp(config.Apply, q); + } + "#, + &expect![[r#" + callable_params: 1 + param: callable_id=5, path=[0], ty=(Qubit => Unit) + call_sites: 1 + site: hof=ApplyOp, arg=Global(H, Body)"#]], + ); +} + +#[test] +fn udt_field_via_let_binding() { + check_analysis( + r#" + struct Config { Apply : Qubit => Unit } + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + let config = new Config { Apply = H }; + let f = config.Apply; + ApplyOp(f, q); + } + "#, + &expect![[r#" + callable_params: 1 + param: callable_id=5, path=[0], ty=(Qubit => Unit) + call_sites: 1 + site: hof=ApplyOp, arg=Global(H, Body) + lattice states: + callable Main: + 3: Single(H:Body)"#]], + ); +} + +#[test] +fn udt_field_in_hof_body() { + check_analysis( + r#" + struct Config { Op : Qubit => Unit } + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation 
RunWithConfig(config : Config, q : Qubit) : Unit { + ApplyOp(config.Op, q); + } + operation Main() : Unit { + use q = Qubit(); + let config = new Config { Op = H }; + RunWithConfig(config, q); + } + "#, + &expect![[r#" + callable_params: 2 + param: callable_id=6, path=[0], ty=(Qubit => Unit) + param: callable_id=3, path=[0, 0], ty=(Qubit => Unit) + call_sites: 2 + site: hof=RunWithConfig, arg=Global(H, Body) + site: hof=ApplyOp, arg=Dynamic"#]], + ); +} + +#[test] +fn udt_field_in_hof_body_defunctionalizes_end_to_end() { + check( + r#" + struct Config { Op : Qubit => Unit } + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation RunWithConfig(config : Config, q : Qubit) : Unit { + ApplyOp(config.Op, q); + } + operation Main() : Unit { + use q = Qubit(); + let config = new Config { Op = H }; + RunWithConfig(config, q); + } + "#, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + Main: input_ty=Unit + RunWithConfig{H}: input_ty=(Unit, Qubit)"#]], + ); +} + +#[test] +fn udt_field_in_hof_body_full_pipeline_invariants() { + check_pipeline( + r#" + struct Config { Op : Qubit => Unit } + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation RunWithConfig(config : Config, q : Qubit) : Unit { + ApplyOp(config.Op, q); + } + operation Main() : Unit { + use q = Qubit(); + let config = new Config { Op = H }; + RunWithConfig(config, q); + } + "#, + ); +} + +#[test] +fn udt_field_nested_two_level() { + check_analysis( + r#" + struct InnerConfig { Apply : Qubit => Unit } + struct OuterConfig { Inner : InnerConfig, Label : Int } + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + let outer = new OuterConfig { + Inner = new InnerConfig { Apply = H }, + Label = 0, + }; + ApplyOp(outer.Inner.Apply, q); + } + "#, + &expect![[r#" + callable_params: 1 + param: callable_id=6, path=[0], ty=(Qubit => Unit) + call_sites: 1 + site: hof=ApplyOp, arg=Global(H, Body)"#]], + 
); +} + +#[test] +fn udt_field_nested_two_level_defunctionalizes_end_to_end() { + check( + r#" + struct InnerConfig { Apply : Qubit => Unit } + struct OuterConfig { Inner : InnerConfig, Label : Int } + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + let outer = new OuterConfig { + Inner = new InnerConfig { Apply = H }, + Label = 0, + }; + ApplyOp(outer.Inner.Apply, q); + } + "#, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); +} + +#[test] +fn udt_field_closure_value() { + check_analysis( + r#" + struct Config { Op : Qubit => Unit } + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + let angle = 1.0; + let config = new Config { Op = q1 => Rx(angle, q1) }; + ApplyOp(config.Op, q); + } + "#, + &expect![[r#" + callable_params: 1 + param: callable_id=6, path=[0], ty=(Qubit => Unit) + call_sites: 1 + site: hof=ApplyOp, arg=Closure(target=4, Body)"#]], + ); +} + +#[test] +fn udt_field_from_parameter_dynamic() { + check_analysis( + r#" + struct Config { Op : Qubit => Unit } + operation MakeConfig() : Config { + new Config { Op = H } + } + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + let c = MakeConfig(); + ApplyOp(c.Op, q); + } + "#, + &expect![[r#" + callable_params: 1 + param: callable_id=6, path=[0], ty=(Qubit => Unit) + call_sites: 1 + site: hof=ApplyOp, arg=Dynamic"#]], + ); +} + +#[test] +fn identity_closure_over_global_callable_collapses() { + check_invariants( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(a => H(a), q); + } + "#, + ); +} + +#[test] +fn identity_closure_wrapping_param() { + check_invariants( + r#" + operation Inner(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Outer(action : Qubit => Unit, q : Qubit) 
: Unit { + Inner(a => action(a), q); + } + operation Main() : Unit { + use q = Qubit(); + Outer(H, q); + } + "#, + ); +} + +#[test] +fn non_identity_closure_preserved() { + check_analysis( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(a => { H(a); X(a); }, q); + } + "#, + &expect![[r#" + callable_params: 1 + param: callable_id=4, path=[0], ty=(Qubit => Unit) + call_sites: 1 + site: hof=ApplyOp, arg=Closure(target=3, Body)"#]], + ); +} + +#[test] +fn identity_closure_tuple_args() { + check_invariants( + r#" + operation Pair(a : Qubit, b : Qubit) : Unit { + H(a); + H(b); + } + operation HOF2(op : (Qubit, Qubit) => Unit, q1 : Qubit, q2 : Qubit) : Unit { + op(q1, q2); + } + operation Main() : Unit { + use q1 = Qubit(); + use q2 = Qubit(); + HOF2((a, b) => Pair(a, b), q1, q2); + } + "#, + ); +} + +#[test] +fn closure_with_captures_not_identity() { + check_analysis( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + let angle = 1.0; + ApplyOp(a => Rx(angle, a), q); + } + "#, + &expect![[r#" + callable_params: 1 + param: callable_id=4, path=[0], ty=(Qubit => Unit) + call_sites: 1 + site: hof=ApplyOp, arg=Closure(target=3, Body)"#]], + ); +} + +#[test] +fn partial_application_lambda_analysis_shape() { + check( + r#" + operation ApplyOp(op : Qubit[] => Unit, register : Qubit[]) : Unit { + op(register); + } + operation Shifted(shift : Int, register : Qubit[]) : Unit { + ApplyXorInPlace(shift, register); + } + operation Main() : Unit { + use register = Qubit[2]; + ApplyOp(register => Shifted(1, register), register); + } + "#, + &expect![ + ": input_ty=((Qubit)[],)\nApplyOp{closure}: input_ty=(Qubit)[]\nMain: input_ty=Unit\nShifted: input_ty=(Int, (Qubit)[])" + ], + ); +} + +#[test] +fn reaching_def_mutable_single_assign() { + check_invariants( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + 
op(q); + } + operation Main() : Unit { + use q = Qubit(); + mutable op = H; + ApplyOp(op, q); + } + "#, + ); +} + +#[test] +fn reaching_def_conditional_both_known() { + check_invariants( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + let f = if true { H } else { X }; + ApplyOp(f, q); + } + "#, + ); +} + +#[test] +fn reaching_def_mutable_multi_assign() { + check_invariants( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + mutable op = H; + if true { set op = X; } + ApplyOp(op, q); + } + "#, + ); +} + +#[test] +fn reaching_def_mutable_both_branches() { + check_invariants( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + mutable op = H; + if true { set op = X; } else { set op = S; } + ApplyOp(op, q); + } + "#, + ); +} + +#[test] +fn reaching_def_mutable_in_loop_dynamic() { + check_errors( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + mutable op = H; + for _ in 0..3 { set op = X; } + ApplyOp(op, q); + } + "#, + &expect!["callable argument could not be resolved statically"], + ); +} + +#[test] +fn analysis_closure_through_multiple_levels() { + check_analysis( + r#" + operation Inner(op : Qubit => Unit, q : Qubit) : Unit { op(q); } + operation Outer(op : Qubit => Unit, q : Qubit) : Unit { Inner(op, q); } + operation Main() : Unit { + use q = Qubit(); + Outer(q1 => H(q1), q); + } + "#, + &expect![[r#" + callable_params: 2 + param: callable_id=5, path=[0], ty=(Qubit => Unit) + param: callable_id=7, path=[0], ty=(Qubit => Unit) + call_sites: 2 + site: hof=Inner, arg=Dynamic + site: hof=Outer, arg=Global(H, Body)"#]], + ); +} + +#[test] +fn analysis_callable_returned_from_function() { + check_analysis( + r#" + operation GetOp() : Qubit => Unit { H } + 
operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { op(q); } + operation Main() : Unit { + use q = Qubit(); + let op = GetOp(); + ApplyOp(op, q); + } + "#, + &expect![ + "callable_params: 1\n param: callable_id=5, path=[0], ty=(Qubit => Unit)\ncall_sites: 1\n site: hof=ApplyOp, arg=Global(H, Body)\nlattice states:\n callable Main:\n 2: Single(H:Body)" + ], + ); +} + +#[test] +fn callable_from_function_return_resolves_statically() { + check_invariants( + r#" + function GetOp() : (Qubit => Unit) { H } + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(GetOp(), q); + } + "#, + ); +} + +#[test] +fn callable_returning_partial_application_resolves_statically() { + check_invariants( + r#" + operation ApplyOp(op : ((Qubit[], Qubit) => Unit), register : Qubit[], target : Qubit) : Unit { + op(register, target); + } + + operation ApplyParityOperation(bits : Bool[], register : Qubit[], target : Qubit) : Unit { + if bits[0] { + CNOT(register[0], target); + } + } + + operation MakeParity(bits : Bool[]) : (Qubit[], Qubit) => Unit { + return ApplyParityOperation(bits, _, _); + } + + operation Main() : Unit { + use register = Qubit[1]; + use target = Qubit(); + let op = MakeParity([true]); + ApplyOp(op, register, target); + } + "#, + ); +} + +#[test] +fn analysis_callable_returning_partial_application_with_explicit_return() { + check_analysis( + r#" + operation ApplyOp(op : ((Qubit[], Qubit) => Unit), register : Qubit[], target : Qubit) : Unit { + op(register, target); + } + + operation ApplyParityOperation(bits : Bool[], register : Qubit[], target : Qubit) : Unit { + if bits[0] { + CNOT(register[0], target); + } + } + + operation MakeParity(bits : Bool[]) : (Qubit[], Qubit) => Unit { + return ApplyParityOperation(bits, _, _); + } + + operation Main() : Unit { + use register = Qubit[1]; + use target = Qubit(); + let op = MakeParity([true]); + ApplyOp(op, register, target); + } + "#, + &expect![ + 
"callable_params: 1\n param: callable_id=7, path=[0], ty=(((Qubit)[], Qubit) => Unit)\ncall_sites: 1\n site: hof=ApplyOp, arg=Closure(target=5, Body)\nlattice states:\n callable Main:\n 3: Single(Closure(5):Body)" + ], + ); +} + +#[test] +fn callable_returning_partial_application_from_local_arg_preserves_capture_expr() { + check_invariants( + r#" + operation UseOracle(oracle : ((Qubit[], Qubit) => Unit), n : Int) : Unit { + use register = Qubit[n]; + use target = Qubit(); + oracle(register, target); + Reset(target); + ResetAll(register); + } + + operation ApplyParityOperation(bits : Bool[], register : Qubit[], target : Qubit) : Unit { + if bits[0] { + CNOT(register[0], target); + } + } + + operation Encode(bits : Bool[]) : (Qubit[], Qubit) => Unit { + ApplyParityOperation(bits, _, _) + } + + operation Main() : Unit { + let bits = [true]; + let oracle = Encode(bits); + UseOracle(oracle, Length(bits)); + } + "#, + ); +} + +#[test] +fn callable_from_array_index_resolves_statically() { + check_invariants( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { op(q); } + operation Main() : Unit { + use q = Qubit(); + let ops = [H, X]; + ApplyOp(ops[0], q); + } + "#, + ); +} + +#[test] +fn callable_returning_partial_application_from_function_resolves_statically() { + check_invariants( + r#" + operation ApplyOp(op : ((Qubit[], Qubit) => Unit), register : Qubit[], target : Qubit) : Unit { + op(register, target); + } + + operation ApplyParityOperation(value : Int, register : Qubit[], target : Qubit) : Unit { + if value == 1 { + CNOT(register[0], target); + } + } + + function Encode(value : Int) : (Qubit[], Qubit) => Unit { + return ApplyParityOperation(value, _, _); + } + + operation Main() : Unit { + use register = Qubit[1]; + use target = Qubit(); + let value = 1; + let oracle = Encode(value); + ApplyOp(oracle, register, target); + } + "#, + ); +} + +#[test] +fn analysis_callable_from_constant_callable_array_loop() { + check_analysis( + r#" + operation 
ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + + operation Main() : Unit { + use q = Qubit(); + let ops = [H, X]; + for op in ops { + ApplyOp(op, q); + } + } + "#, + &expect![ + "callable_params: 1\n param: callable_id=4, path=[0], ty=(Qubit => Unit is Adj + Ctl)\ncall_sites: 2\n site: hof=ApplyOp, arg=Global(H, Body)\n site: hof=ApplyOp, arg=Global(X, Body)\nlattice states:\n callable Main:\n 7: Multi([H:Body, X:Body])" + ], + ); +} + +#[test] +fn analysis_callable_returning_partial_application_from_function_in_loop() { + check_analysis( + r#" + operation ApplyOp(op : ((Qubit[], Qubit) => Unit), register : Qubit[], target : Qubit) : Unit { + op(register, target); + } + + operation ApplyParityOperation(value : Int, register : Qubit[], target : Qubit) : Unit { + if value == 1 { + CNOT(register[0], target); + } + } + + function Encode(value : Int) : (Qubit[], Qubit) => Unit { + return ApplyParityOperation(value, _, _); + } + + operation Main() : Unit { + use register = Qubit[1]; + use target = Qubit(); + for value in [1, 2] { + let oracle = Encode(value); + ApplyOp(oracle, register, target); + } + } + "#, + &expect![ + "callable_params: 1\n param: callable_id=8, path=[0], ty=(((Qubit)[], Qubit) => Unit)\ncall_sites: 1\n site: hof=ApplyOp, arg=Closure(target=5, Body)\nlattice states:\n callable Main:\n 8: Single(Closure(5):Body)" + ], + ); +} + +#[test] +fn reaching_def_mutable_in_while_loop() { + check_errors( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { op(q); } + operation Main() : Unit { + use q = Qubit(); + mutable op = H; + mutable n = 3; + while n > 0 { + op = X; + n -= 1; + } + ApplyOp(op, q); + } + "#, + &expect!["callable argument could not be resolved statically"], + ); +} + +#[test] +fn analysis_nested_callable_in_tuple_param() { + check_analysis( + r#" + operation Wrapper(pair : (Qubit => Unit, Int), q : Qubit) : Unit { + let (op, _) = pair; + op(q); + } + operation Main() : Unit { + use q = Qubit(); + Wrapper((H, 42), 
q); + } + "#, + &expect![ + "callable_params: 1\n param: callable_id=3, path=[0, 0], ty=(Qubit => Unit is Adj + Ctl)\ncall_sites: 1\n site: hof=Wrapper, arg=Global(H, Body)" + ], + ); +} + +#[test] +fn analysis_nested_callable_second_element() { + check_analysis( + r#" + operation Wrapper(pair : (Int, Qubit => Unit), q : Qubit) : Unit { + let (_, op) = pair; + op(q); + } + operation Main() : Unit { + use q = Qubit(); + Wrapper((42, H), q); + } + "#, + &expect![[r#" + callable_params: 1 + param: callable_id=3, path=[0, 1], ty=(Qubit => Unit is Adj + Ctl) + call_sites: 1 + site: hof=Wrapper, arg=Global(H, Body)"#]], + ); +} + +#[test] +fn analysis_nested_callable_single_param_supported() { + check_analysis( + r#" + operation Wrapper(pair : (Qubit => Unit, Int)) : Unit { + let (op, _) = pair; + use q = Qubit(); + op(q); + } + operation Main() : Unit { + Wrapper((H, 42)); + } + "#, + &expect![ + "callable_params: 1\n param: callable_id=3, path=[0, 0], ty=(Qubit => Unit is Adj + Ctl)\ncall_sites: 1\n site: hof=Wrapper, arg=Global(H, Body)" + ], + ); +} + +#[test] +fn analysis_branch_split_nested_callable_in_tuple() { + check_analysis( + r#" + operation Wrapper(pair : (Qubit => Unit, Int), q : Qubit) : Unit { + let (op, _) = pair; + op(q); + } + operation Main() : Unit { + use q = Qubit(); + let f = if true { H } else { X }; + Wrapper((f, 42), q); + } + "#, + &expect![[r#" + callable_params: 1 + param: callable_id=3, path=[0, 0], ty=(Qubit => Unit is Adj + Ctl) + call_sites: 2 + site: hof=Wrapper, arg=Global(H, Body) + site: hof=Wrapper, arg=Global(X, Body) + lattice states: + callable Main: + 2: Multi([H:Body, X:Body])"#]], + ); +} + +#[test] +fn analysis_nested_callable_single_param_second_element_supported() { + check_analysis( + r#" + operation Wrapper(pair : (Int, Qubit => Unit)) : Unit { + let (_, op) = pair; + use q = Qubit(); + op(q); + } + operation Main() : Unit { + Wrapper((42, H)); + } + "#, + &expect![[r#" + callable_params: 1 + param: callable_id=3, 
path=[0, 1], ty=(Qubit => Unit is Adj + Ctl) + call_sites: 1 + site: hof=Wrapper, arg=Global(H, Body)"#]], + ); +} + +#[test] +fn analysis_nested_callable_single_param_recursive_supported() { + check_analysis( + r#" + operation Wrapper(bundle : (((Qubit => Unit, Int), Double), Qubit)) : Unit { + let (((op, n), angle), q) = bundle; + let _ = n; + let _ = angle; + op(q); + } + operation Main() : Unit { + use q = Qubit(); + Wrapper((((H, 42), 1.0), q)); + } + "#, + &expect![ + "callable_params: 1\n param: callable_id=3, path=[0, 0, 0, 0], ty=(Qubit => Unit is Adj + Ctl)\ncall_sites: 1\n site: hof=Wrapper, arg=Global(H, Body)" + ], + ); +} + +#[test] +fn identity_closure_adjoint_wrapped_collapses() { + check_analysis( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(q1 => Adjoint S(q1), q); + } + "#, + &expect![[r#" + callable_params: 1 + param: callable_id=4, path=[0], ty=(Qubit => Unit) + call_sites: 1 + site: hof=ApplyOp, arg=Global(S, Adj) + direct_call_sites: 1 + site: callee=S:Adj, default"#]], + ); +} + +#[test] +fn single_use_immutable_local_promoted() { + check_analysis( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + let op = H; + ApplyOp(op, q); + } + "#, + &expect![[r#" + callable_params: 1 + param: callable_id=3, path=[0], ty=(Qubit => Unit is Adj + Ctl) + call_sites: 1 + site: hof=ApplyOp, arg=Global(H, Body) + lattice states: + callable Main: + 2: Single(H:Body)"#]], + ); +} + +#[test] +fn multi_use_immutable_local_not_promoted() { + check_analysis( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q1 = Qubit(); + use q2 = Qubit(); + let op = H; + ApplyOp(op, q1); + ApplyOp(op, q2); + } + "#, + &expect![[r#" + callable_params: 1 + param: callable_id=3, path=[0], ty=(Qubit => Unit is Adj + Ctl) + call_sites: 2 + site: 
hof=ApplyOp, arg=Global(H, Body) + site: hof=ApplyOp, arg=Global(H, Body) + lattice states: + callable Main: + 3: Single(H:Body)"#]], + ); +} + +#[test] +fn mutable_local_not_promoted() { + check_analysis( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + mutable op = H; + op = X; + ApplyOp(op, q); + } + "#, + &expect![[r#" + callable_params: 1 + param: callable_id=3, path=[0], ty=(Qubit => Unit is Adj + Ctl) + call_sites: 1 + site: hof=ApplyOp, arg=Global(X, Body) + lattice states: + callable Main: + 2: Single(X:Body)"#]], + ); +} + +#[test] +fn analysis_conditional_callable_binding_produces_multi_lattice() { + check_analysis( + r#" + operation ApplyConditional(power : Int, target : Qubit) : Unit { + let u = if power >= 0 { S } else { Adjoint S }; + u(target); + } + + operation Main() : Unit { + use q = Qubit(); + ApplyConditional(3, q); + } + "#, + &expect![[r#" + callable_params: 0 + call_sites: 0 + direct_call_sites: 2 + site: callee=S:Adj, default + site: callee=S:Body, condition=ExprId(4) + lattice states: + callable ApplyConditional: + 3: Multi([S:Body, S:Adj])"#]], + ); +} + +#[test] +fn analysis_callable_from_tuple_destructured_array_iteration() { + check_analysis( + r#" + namespace Test { + @EntryPoint() + operation Main() : Unit { + let arr = [(S, PauliZ), (T, PauliX)]; + for (op, _basis) in arr { + use q = Qubit(); + op(q); + } + } + } + "#, + &expect![[r#" + callable_params: 0 + call_sites: 0 + direct_call_sites: 2 + site: callee=S:Body, default + site: callee=T:Body, default + lattice states: + callable Main: + 5: Multi([S:Body, T:Body])"#]], + ); +} + +#[test] +fn resolve_captures_missing_binding_returns_none() { + let package = Package { + items: IndexMap::new(), + entry: None, + entry_exec_graph: qsc_fir::fir::ExecGraph::default(), + blocks: IndexMap::new(), + exprs: IndexMap::new(), + pats: IndexMap::new(), + stmts: IndexMap::new(), + }; + let locals = 
LocalState::default(); + let missing_var = LocalVarId::from(99usize); + + let captures = resolve_captures(&package, &locals, &[missing_var], &FxHashSet::default()); + + assert!( + captures.is_none(), + "missing capture bindings should degrade analysis instead of panicking" + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/defunctionalize/tests/cross_package.rs b/source/compiler/qsc_fir_transforms/src/defunctionalize/tests/cross_package.rs new file mode 100644 index 0000000000..12c8f2fbcd --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/defunctionalize/tests/cross_package.rs @@ -0,0 +1,436 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use super::*; +use expect_test::expect; + +#[test] +fn analysis_apply_operation_power_ca_consumer() { + check_analysis_with_capabilities( + r#" + operation Consume(apply_power_of_u : (Int, Qubit[]) => Unit is Adj + Ctl, target : Qubit[]) : Unit { + apply_power_of_u(1, target); + } + + operation U(qs : Qubit[]) : Unit is Adj + Ctl { + H(qs[0]); + } + + operation Main() : Unit { + use qs = Qubit[1]; + Consume(ApplyOperationPowerCA(_, U, _), qs); + } + "#, + adaptive_qirgen_capabilities(), + &expect![[r#" + callable_params: 3 + param: callable_id=4, path=[0], ty=((Qubit)[] => Unit is Adj + Ctl) + param: callable_id=6, path=[1], ty=((Qubit)[] => Unit is Adj + Ctl) + param: callable_id=7, path=[0], ty=((Int, (Qubit)[]) => Unit is Adj + Ctl) + call_sites: 5 + site: hof=ApplyOperationPowerCA<(Qubit)[], AdjCtl>, arg=Dynamic + site: hof=ApplyOperationPowerCA<(Qubit)[], AdjCtl>, arg=Dynamic + site: hof=ApplyOperationPowerCA<(Qubit)[], AdjCtl>, arg=Dynamic + site: hof=ApplyOperationPowerCA<(Qubit)[], AdjCtl>, arg=Dynamic + site: hof=Consume, arg=Closure(target=4, Body) + direct_call_sites: 3 + site: callee=H:Adj, default + site: callee=H:Ctl, default + site: callee=H:CtlAdj, default + lattice states: + callable ApplyOperationPowerCA<(Qubit)[], AdjCtl>: + 3: Dynamic + 8: Dynamic + 15: 
Dynamic + 21: Dynamic"#]], + ); +} + +#[test] +fn analysis_bernstein_vazirani_sample_shape() { + check_analysis_with_capabilities( + r#" + import Std.Arrays.*; + import Std.Convert.*; + import Std.Diagnostics.*; + import Std.Math.*; + import Std.Measurement.*; + + operation Main() : Unit { + let nQubits = 10; + let integers = [127, 238, 512]; + for integer in integers { + let parityOperation = EncodeIntegerAsParityOperation(integer); + let _ = BernsteinVazirani(parityOperation, nQubits); + } + } + + operation BernsteinVazirani(Uf : ((Qubit[], Qubit) => Unit), n : Int) : Result[] { + use queryRegister = Qubit[n]; + use target = Qubit(); + X(target); + within { + ApplyToEachA(H, queryRegister); + } apply { + H(target); + Uf(queryRegister, target); + } + let resultArray = MResetEachZ(queryRegister); + Reset(target); + resultArray + } + + operation ApplyParityOperation(bitStringAsInt : Int, xRegister : Qubit[], yQubit : Qubit) : Unit { + let requiredBits = BitSizeI(bitStringAsInt); + let availableQubits = Length(xRegister); + Fact(availableQubits >= requiredBits, "enough qubits"); + for index in IndexRange(xRegister) { + if ((bitStringAsInt &&& 2^index) != 0) { + CNOT(xRegister[index], yQubit); + } + } + } + + function EncodeIntegerAsParityOperation(bitStringAsInt : Int) : (Qubit[], Qubit) => Unit { + return ApplyParityOperation(bitStringAsInt, _, _); + } + "#, + adaptive_qirgen_capabilities(), + &expect![[r#" + callable_params: 2 + param: callable_id=10, path=[0], ty=(((Qubit)[], Qubit) => Unit) + param: callable_id=6, path=[0], ty=(Qubit => Unit is Adj + Ctl) + call_sites: 3 + site: hof=ApplyToEachA, arg=Global(H, Body) + site: hof=ApplyToEachA, arg=Global(H, Body) + site: hof=BernsteinVazirani, arg=Closure(target=5, Body) + lattice states: + callable Main: + 7: Single(Closure(5):Body)"#]], + ); +} + +#[test] +fn analysis_deutsch_jozsa_sample_shape() { + check_analysis_with_capabilities( + r#" + import Std.Diagnostics.*; + import Std.Math.*; + import 
Std.Measurement.*; + + operation Main() : Unit { + let functionsToTest = [SimpleConstantBoolF, SimpleBalancedBoolF, ConstantBoolF, BalancedBoolF]; + for fn in functionsToTest { + let _ = DeutschJozsa(fn, 5); + } + } + + operation DeutschJozsa(Uf : ((Qubit[], Qubit) => Unit), n : Int) : Bool { + use queryRegister = Qubit[n]; + use target = Qubit(); + X(target); + H(target); + within { + for q in queryRegister { + H(q); + } + } apply { + Uf(queryRegister, target); + } + mutable result = true; + for q in queryRegister { + if MResetZ(q) == One { + result = false; + } + } + Reset(target); + result + } + + operation SimpleConstantBoolF(args : Qubit[], target : Qubit) : Unit { + X(target); + } + + operation SimpleBalancedBoolF(args : Qubit[], target : Qubit) : Unit { + CX(args[0], target); + } + + operation ConstantBoolF(args : Qubit[], target : Qubit) : Unit { + for i in 0..(2^Length(args)) - 1 { + ApplyControlledOnInt(i, X, args, target); + } + } + + operation BalancedBoolF(args : Qubit[], target : Qubit) : Unit { + for i in 0..2..(2^Length(args)) - 1 { + ApplyControlledOnInt(i, X, args, target); + } + } + "#, + adaptive_qirgen_capabilities(), + &expect![[r#" + callable_params: 2 + param: callable_id=8, path=[1], ty=(Qubit => Unit is Adj + Ctl) + param: callable_id=10, path=[0], ty=(((Qubit)[], Qubit) => Unit) + call_sites: 6 + site: hof=ApplyControlledOnInt, arg=Global(X, Body) + site: hof=ApplyControlledOnInt, arg=Global(X, Body) + site: hof=DeutschJozsa, arg=Global(SimpleConstantBoolF, Body) + site: hof=DeutschJozsa, arg=Global(SimpleBalancedBoolF, Body) + site: hof=DeutschJozsa, arg=Global(ConstantBoolF, Body) + site: hof=DeutschJozsa, arg=Global(BalancedBoolF, Body) + direct_call_sites: 5 + site: callee=ApplyPauliFromInt:Adj, default + site: callee=ApplyPauliFromInt:Adj, default + site: callee=ApplyPauliFromInt:Adj, default + site: callee=ApplyPauliFromInt:Adj, default + site: callee=H:Adj, default + lattice states: + callable Main: + 5: 
Multi([SimpleConstantBoolF:Body, SimpleBalancedBoolF:Body, ConstantBoolF:Body, BalancedBoolF:Body])"#]], + ); +} + +#[test] +fn full_pipeline_handles_stdlib_apply_to_each() { + check_pipeline( + r#" + open Std.Canon; + operation Main() : Unit { + use qs = Qubit[3]; + ApplyToEach(H, qs); + } + "#, + ); +} + +#[test] +fn full_pipeline_handles_stdlib_apply_to_each_with_custom_intrinsic() { + check_pipeline( + r#" + open Std.Canon; + operation Main() : Unit { + use qs = Qubit[3]; + ApplyToEach(SX, qs); + } + "#, + ); +} + +#[test] +fn apply_to_each_body_callable_defunctionalizes() { + check_invariants( + r#" + open Std.Canon; + operation Main() : Unit { + use qs = Qubit[3]; + ApplyToEach(H, qs); + } + "#, + ); +} + +#[test] +fn apply_to_each_a_adjoint_callable_defunctionalizes() { + check_invariants( + r#" + open Std.Canon; + operation Main() : Unit { + use qs = Qubit[3]; + ApplyToEachA(S, qs); + Adjoint ApplyToEachA(S, qs); + } + "#, + ); +} + +#[test] +fn apply_to_each_c_controlled_callable_defunctionalizes() { + check_invariants( + r#" + open Std.Canon; + operation Main() : Unit { + use (ctl, qs) = (Qubit(), Qubit[3]); + ApplyToEachC(X, qs); + } + "#, + ); +} + +#[test] +fn apply_to_each_ca_callable_defunctionalizes() { + check_invariants( + r#" + open Std.Canon; + operation Main() : Unit { + use qs = Qubit[3]; + ApplyToEachCA(S, qs); + } + "#, + ); +} + +#[test] +fn cross_package_apply_to_each_closure_arg_defunctionalizes() { + check_invariants( + r#" + open Std.Canon; + operation Main() : Unit { + use qs = Qubit[3]; + let angle = 1.0; + ApplyToEach(q => Rx(angle, q), qs); + } + "#, + ); +} + +#[test] +fn cross_package_apply_to_each_adjoint_arg_defunctionalizes() { + check_invariants( + r#" + open Std.Canon; + operation Main() : Unit { + use qs = Qubit[3]; + ApplyToEach(Adjoint S, qs); + } + "#, + ); +} + +#[test] +fn adjoint_cross_package_apply_to_each_ca_defunctionalizes() { + check_invariants( + r#" + open Std.Canon; + operation Main() : Unit { + use qs = 
Qubit[3]; + Adjoint ApplyToEachCA(S, qs); + } + "#, + ); +} + +#[test] +fn controlled_apply_to_each_ca_keeps_body_callable_static() { + check_pipeline( + r#" + open Std.Canon; + + operation PrepareUniform(inputQubits : Qubit[]) : Unit is Adj + Ctl { + ApplyToEachCA(H, inputQubits); + } + + operation PrepareAllOnes(inputQubits : Qubit[]) : Unit is Adj + Ctl { + ApplyToEachCA(X, inputQubits); + } + + @EntryPoint() + operation Main() : Unit { + use qs = Qubit[3]; + let register = [qs[1], qs[2]]; + Controlled PrepareUniform([qs[0]], register); + Controlled PrepareAllOnes([qs[0]], register); + } + "#, + ); +} + +#[test] +fn cross_package_mapped_defunctionalizes() { + check_pipeline( + r#" + open Std.Arrays; + function Double(x : Int) : Int { x * 2 } + @EntryPoint() + operation Main() : Unit { + let arr = [1, 2, 3]; + let _ = Mapped(Double, arr); + } + "#, + ); +} + +#[test] +fn cross_package_for_each_defunctionalizes() { + check_pipeline( + r#" + open Std.Arrays; + operation Main() : Unit { + use qs = Qubit[3]; + ForEach(H, qs); + } + "#, + ); +} + +#[test] +fn stdlib_hof_specialized_with_concrete_callable() { + check( + r#" + open Microsoft.Quantum.Arrays; + + operation Main() : Int[] { + let arr = [1, 2, 3]; + Mapped(x -> x + 1, arr) + } + "#, + &expect![[r#" + : input_ty=(Int,) + Length: input_ty=(Int)[] + Main: input_ty=Unit + Mapped{closure}: input_ty=(Int)[]"#]], + ); +} + +#[test] +fn lambda_expression_sample_shape_has_no_defunctionalization_errors() { + check_errors( + r#" + import Std.Arrays.*; + + operation Main() : Unit { + let add = (x, y) -> x + y; + let _ = add(2, 3); + + use control = Qubit(); + let cnotOnControl = q => CNOT(control, q); + + let intArray = [1, 2, 3, 4, 5]; + let _ = Fold(add, 0, intArray); + let _ = Mapped(x -> x + 1, intArray); + } + "#, + &expect!["(no error)"], + ); +} + +#[test] +fn partial_application_sample_shape_has_no_defunctionalization_errors() { + check_errors( + r#" + import Std.Arrays.*; + + function Main() : Unit { + let 
incrementByOne = Add(_, 1); + let incrementByOneLambda = x -> Add(x, 1); + + let _ = incrementByOne(4); + + let sumAndAddOne = AddMany(_, _, _, 1); + let sumAndAddOneLambda = (a, b, c) -> AddMany(a, b, c, 1); + + let intArray = [1, 2, 3, 4, 5]; + let _ = Mapped(Add(_, 1), intArray); + } + + function Add(x : Int, y : Int) : Int { + return x + y; + } + + function AddMany(a : Int, b : Int, c : Int, d : Int) : Int { + return a + b + c + d; + } + "#, + &expect!["(no error)"], + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/defunctionalize/tests/fixpoint.rs b/source/compiler/qsc_fir_transforms/src/defunctionalize/tests/fixpoint.rs new file mode 100644 index 0000000000..bb4aafe945 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/defunctionalize/tests/fixpoint.rs @@ -0,0 +1,656 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use super::*; +use expect_test::expect; +use std::fmt::Write; + +#[test] +fn program_without_hofs_converges_without_changes() { + check( + r#" + operation Main() : Unit { + use q = Qubit(); + H(q); + } + "#, + &expect![[r#" + Main: input_ty=Unit"#]], + ); +} + +#[test] +fn fixpoint_no_hof_call_sites_prunes_dead_callable_local_chain() { + check_invariants( + r#" + operation Main() : Unit { + let first : Int -> Bool = (value) -> value == 0; + let second : Int -> Bool = first; + } + "#, + ); +} + +#[test] +fn fixpoint_multi_level_hof() { + check_invariants( + r#" + operation ApplyInner(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation ApplyOuter(op : Qubit => Unit, q : Qubit) : Unit { + ApplyInner(op, q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOuter(H, q); + } + "#, + ); +} + +#[test] +fn invariant_after_fixpoint() { + check_invariants( + r#" + operation Inner(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Outer(op : Qubit => Unit, q : Qubit) : Unit { + Inner(op, q); + } + operation Main() : Unit { + use q = Qubit(); + Outer(H, q); + } + "#, + ); 
+} + +#[test] +fn full_pipeline_succeeds_for_simple_hof() { + check_pipeline( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(H, q); + } + "#, + ); +} + +#[test] +fn nested_hof_two_levels() { + check_invariants( + r#" + operation Level1(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Level2(op : Qubit => Unit, q : Qubit) : Unit { + Level1(op, q); + } + operation Main() : Unit { + use q = Qubit(); + Level2(H, q); + } + "#, + ); +} + +#[test] +fn nested_hof_convergence() { + check_invariants( + r#" + operation L1(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation L2(op : Qubit => Unit, q : Qubit) : Unit { + L1(op, q); + } + operation L3(op : Qubit => Unit, q : Qubit) : Unit { + L2(op, q); + } + operation Main() : Unit { + use q = Qubit(); + L3(H, q); + } + "#, + ); +} + +#[test] +fn nested_hof_forwarding_with_adjoint() { + check_invariants( + r#" + operation Inner(op : Qubit => Unit is Adj, q : Qubit) : Unit { + op(q); + } + operation Outer(op : Qubit => Unit is Adj, q : Qubit) : Unit { + Inner(Adjoint op, q); + } + operation Main() : Unit { + use q = Qubit(); + Outer(S, q); + } + "#, + ); +} + +#[test] +fn nested_hof_controlled_forwarding() { + check_invariants( + r#" + operation Inner(op : Qubit => Unit is Ctl, q : Qubit) : Unit { + op(q); + } + operation Outer(op : Qubit => Unit is Ctl, q : Qubit) : Unit { + Inner(op, q); + } + operation Main() : Unit { + use q = Qubit(); + Outer(X, q); + } + "#, + ); +} + +#[test] +fn nested_hof_four_levels() { + check_invariants( + r#" + operation L1(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation L2(op : Qubit => Unit, q : Qubit) : Unit { + L1(op, q); + } + operation L3(op : Qubit => Unit, q : Qubit) : Unit { + L2(op, q); + } + operation L4(op : Qubit => Unit, q : Qubit) : Unit { + L3(op, q); + } + operation Main() : Unit { + use q = Qubit(); + L4(H, q); + } + "#, + ); +} + +#[test] +fn 
nested_hof_two_call_sites_different_args() { + check_invariants( + r#" + operation Inner(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Outer(op : Qubit => Unit, q : Qubit) : Unit { + Inner(op, q); + } + operation Main() : Unit { + use q = Qubit(); + Outer(H, q); + Outer(X, q); + } + "#, + ); +} + +#[test] +fn nested_hof_forwarding_adj_autogen() { + check_invariants( + r#" + operation Inner(op : Qubit => Unit is Adj, q : Qubit) : Unit is Adj { + op(q); + } + operation Outer(op : Qubit => Unit is Adj, q : Qubit) : Unit is Adj { + Inner(op, q); + } + operation Main() : Unit { + use q = Qubit(); + Outer(S, q); + Adjoint Outer(S, q); + } + "#, + ); +} + +#[test] +fn nested_hof_requires_multi_iteration_convergence() { + check( + r#" + operation ApplyTwice(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + op(q); + } + + operation ApplyAndMeasure(action : (Qubit => Unit, Qubit) => Unit, op : Qubit => Unit, q : Qubit) : Result { + action(op, q); + M(q) + } + + operation Main() : Result { + use q = Qubit(); + ApplyAndMeasure(ApplyTwice, H, q) + } + "#, + &expect![[r#" + ApplyAndMeasure{ApplyTwice}{H}: input_ty=Qubit + ApplyTwice{H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); + check_invariants( + r#" + operation ApplyTwice(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + op(q); + } + + operation ApplyAndMeasure(action : (Qubit => Unit, Qubit) => Unit, op : Qubit => Unit, q : Qubit) : Result { + action(op, q); + M(q) + } + + operation Main() : Result { + use q = Qubit(); + ApplyAndMeasure(ApplyTwice, H, q) + } + "#, + ); +} + +#[test] +fn five_level_hof_chain_converges_at_max_iterations_boundary() { + check_invariants( + r#" + operation L1(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation L2(op : Qubit => Unit, q : Qubit) : Unit { + L1(op, q); + } + operation L3(op : Qubit => Unit, q : Qubit) : Unit { + L2(op, q); + } + operation L4(op : Qubit => Unit, q : Qubit) : Unit { + L3(op, q); + } + operation L5(op : Qubit => Unit, q : Qubit) : 
Unit { + L4(op, q); + } + operation Main() : Unit { + use q = Qubit(); + L5(H, q); + } + "#, + ); +} + +#[test] +fn transient_dynamic_resolves_after_outer_hof_specialization() { + check_errors( + r#" + operation ApplyInner(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + + operation ApplyMiddle(op : Qubit => Unit, q : Qubit) : Unit { + ApplyInner(op, q); + } + + operation ApplyOuter(action : (Qubit => Unit, Qubit) => Unit, op : Qubit => Unit, q : Qubit) : Unit { + action(op, q); + } + + operation Main() : Unit { + use q = Qubit(); + ApplyOuter(ApplyMiddle, H, q); + } + "#, + &expect!["(no error)"], + ); + check( + r#" + operation ApplyInner(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + + operation ApplyMiddle(op : Qubit => Unit, q : Qubit) : Unit { + ApplyInner(op, q); + } + + operation ApplyOuter(action : (Qubit => Unit, Qubit) => Unit, op : Qubit => Unit, q : Qubit) : Unit { + action(op, q); + } + + operation Main() : Unit { + use q = Qubit(); + ApplyOuter(ApplyMiddle, H, q); + } + "#, + &expect![[r#" + ApplyInner{H}: input_ty=Qubit + ApplyMiddle{H}: input_ty=Qubit + ApplyOuter{ApplyMiddle}{H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); +} + +/// Regression test for producer-body closure cleanup: a producer function +/// that returns a partial-application closure causes convergence failure +/// when the closure node survives in the producer body after HOF +/// specialization. The closure cleanup pass must replace consumed closures +/// with Unit so that `remaining_callable_value_info` no longer counts them. 
+#[test] +fn producer_body_closure_cleanup_converges() { + check_invariants( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation InnerOp(extra : Bool, q : Qubit) : Unit { + H(q); + } + function MakeOp(extra : Bool) : Qubit => Unit { + return InnerOp(extra, _); + } + operation Main() : Unit { + use q = Qubit(); + let op = MakeOp(true); + ApplyOp(op, q); + } + "#, + ); +} + +/// Two callable arguments passed to a multi-parameter HOF: one partial +/// application closure and one global callable. Both must survive cleanup +/// because they are still live as call arguments. +#[test] +fn closure_in_active_call_arg_survives_cleanup() { + check_invariants( + r#" + operation Apply2(f : Qubit => Unit, g : Qubit => Unit, q : Qubit) : Unit { + f(q); + g(q); + } + operation Inner(extra : Bool, q : Qubit) : Unit { + H(q); + } + operation Main() : Unit { + use q = Qubit(); + let op1 = Inner(true, _); + Apply2(op1, X, q); + } + "#, + ); +} + +/// When a mutable callable variable is reassigned in a loop, the analysis +/// resolves it to `Dynamic` (overdefined). The fixpoint loop detects no +/// progress — remaining callable count is unchanged and no new call sites are +/// discovered — and breaks via stuck detection. The `DynamicCallable` error +/// from the current iteration survives, preventing the post-loop +/// `FixpointNotReached` from firing (which only fires when `errors.is_empty()`). +#[test] +fn stuck_detection_with_unresolvable_callable_emits_dynamic_error() { + check_errors( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + mutable op = H; + for _ in 0..3 { + op = X; + } + ApplyOp(op, q); + } + "#, + &expect!["callable argument could not be resolved statically"], + ); +} + +/// Multi-level HOF chain where each fixpoint iteration resolves one level. 
+/// Confirms that the before/after progress tracking does not cause premature +/// exit when each iteration successfully reduces the remaining count. +#[test] +fn progress_tracking_allows_multi_iteration_convergence() { + check_invariants( + r#" + operation L1(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation L2(inner : (Qubit => Unit, Qubit) => Unit, op : Qubit => Unit, q : Qubit) : Unit { + inner(op, q); + } + operation L3(mid : ((Qubit => Unit, Qubit) => Unit, Qubit => Unit, Qubit) => Unit, inner : (Qubit => Unit, Qubit) => Unit, op : Qubit => Unit, q : Qubit) : Unit { + mid(inner, op, q); + } + operation Main() : Unit { + use q = Qubit(); + L3(L2, L1, H, q); + } + "#, + ); +} + +#[test] +fn pipeline_resolves_conditional_callable_binding() { + check_pipeline( + r#" + operation ApplyPower(power : Int, op : Qubit => Unit is Adj, target : Qubit) : Unit is Adj { + let u = if power >= 0 { op } else { Adjoint op }; + for _ in 1..power { + u(target); + } + } + + @EntryPoint() + operation Main() : Unit { + use q = Qubit(); + ApplyPower(3, S, q); + } + "#, + ); +} + +#[test] +fn pipeline_callable_from_tuple_destructured_array_iteration() { + check_pipeline( + r#" + namespace Test { + @EntryPoint() + operation Main() : Unit { + let arr = [(S, PauliZ), (T, PauliX)]; + for (op, _basis) in arr { + use q = Qubit(); + op(q); + } + } + } + "#, + ); +} + +#[test] +fn pipeline_teleportation_pattern_callable_from_array_of_tuples() { + check_pipeline( + r#" + namespace Test { + operation SetToPlus(q : Qubit) : Unit is Adj + Ctl { + H(q); + } + operation SetToMinus(q : Qubit) : Unit is Adj + Ctl { + X(q); + H(q); + } + + @EntryPoint() + operation Main() : Unit { + let ops = [ + (I, PauliZ), + (X, PauliZ), + (SetToPlus, PauliX), + (SetToMinus, PauliX), + ]; + for (initializer, _basis) in ops { + use q = Qubit(); + initializer(q); + } + } + } + "#, + ); +} + +#[test] +fn pipeline_callable_at_middle_of_three_tuple_from_array_iteration() { + check_pipeline( + r#" + 
namespace Test { + operation SetToPlus(q : Qubit) : Unit is Adj + Ctl { + H(q); + } + operation SetToMinus(q : Qubit) : Unit is Adj + Ctl { + X(q); + H(q); + } + + @EntryPoint() + operation Main() : Unit { + let ops = [ + (PauliZ, I, false), + (PauliZ, X, false), + (PauliX, SetToPlus, true), + (PauliX, SetToMinus, true), + ]; + for (_basis, initializer, _flag) in ops { + use q = Qubit(); + initializer(q); + } + } + } + "#, + ); +} + +#[test] +fn pipeline_teleportation_like_callable_from_string_tagged_triple_array() { + check_pipeline( + r#" + namespace Test { + operation SetToPlus(q : Qubit) : Unit is Adj + Ctl { + H(q); + } + operation SetToMinus(q : Qubit) : Unit is Adj + Ctl { + X(q); + H(q); + } + + @EntryPoint() + operation Main() : Unit { + let ops = [ + (I, PauliZ), + (X, PauliZ), + (SetToPlus, PauliX), + (SetToMinus, PauliX), + ]; + for (initializer, basis) in ops { + use q = Qubit(); + initializer(q); + let _ = Measure([basis], [q]); + Reset(q); + } + } + } + "#, + ); +} + +#[test] +fn pipeline_callable_array_iteration_exceeding_old_multi_cap() { + check_pipeline( + r#" + namespace Test { + operation SX(q : Qubit) : Unit is Adj + Ctl { + Rx(Microsoft.Quantum.Math.PI() / 2.0, q); + } + + @EntryPoint() + operation Main() : Unit { + let gates = [H, X, Y, Z, S, Adjoint S, SX]; + use q = Qubit(); + for gate in gates { + gate(q); + } + } + } + "#, + ); +} + +fn nested_hof_source(level_count: usize) -> String { + assert!(level_count > 0); + + let mut source = String::new(); + source.push_str("operation Level01(op : Qubit => Unit, q : Qubit) : Unit {\n op(q);\n}\n"); + + for level in 2..=level_count { + write!( + &mut source, + "operation Level{level:02}(op : Qubit => Unit, q : Qubit) : Unit {{\n Level{previous:02}(op, q);\n}}\n", + previous = level - 1, + ).expect("failed to write source string"); + } + + write!( + &mut source, + "@EntryPoint()\noperation Main() : Unit {{\n use q = Qubit();\n Level{level_count:02}(H, q);\n}}\n" + ).expect("failed to write source 
string"); + source +} + +#[test] +fn defunc_20_level_hof_returns_fixpoint_reached() { + // Regression test: 20-level HOF nesting is under the convergence cap. + let source = nested_hof_source(20); + + let (mut fir_store, fir_pkg_id) = crate::test_utils::compile_to_monomorphized_fir(&source); + let mut assigner = qsc_fir::assigner::Assigner::from_package(fir_store.get(fir_pkg_id)); + let errors = super::super::defunctionalize(&mut fir_store, fir_pkg_id, &mut assigner); + + assert!( + errors.is_empty(), + "Expected defunctionalization to succeed for 20-level HOF, got: {:?}", + errors.iter().map(ToString::to_string).collect::<Vec<_>>() + ); +} + +#[test] +fn defunc_21_level_hof_returns_static_resolution_error() { + // Regression test: 21-level HOF nesting exceeds the current static + // resolution depth, but still reports a defunctionalization diagnostic + // instead of panicking or lowering invalid FIR. + let source = nested_hof_source(21); + + let (mut fir_store, fir_pkg_id) = crate::test_utils::compile_to_monomorphized_fir(&source); + let mut assigner = qsc_fir::assigner::Assigner::from_package(fir_store.get(fir_pkg_id)); + let errors = super::super::defunctionalize(&mut fir_store, fir_pkg_id, &mut assigner); + + assert!( + !errors.is_empty(), + "Expected defunctionalization error for 21-level HOF" + ); + + assert!( + matches!(errors.as_slice(), [super::super::Error::DynamicCallable(_)]), + "Expected DynamicCallable error, got: {:?}", + errors.iter().map(ToString::to_string).collect::<Vec<_>>() + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/defunctionalize/tests/invariants.rs b/source/compiler/qsc_fir_transforms/src/defunctionalize/tests/invariants.rs new file mode 100644 index 0000000000..1ed2d25983 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/defunctionalize/tests/invariants.rs @@ -0,0 +1,388 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +use super::*; +use expect_test::expect; + +#[test] +fn invariants_single_hof() { + check_invariants( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(H, q); + } + "#, + ); +} + +#[test] +fn invariants_closure_with_captures() { + check_invariants( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + let angle = 1.0; + ApplyOp(q1 => Rx(angle, q1), q); + } + "#, + ); +} + +#[test] +fn invariants_functor_composition() { + check_invariants( + r#" + operation ApplyAdj(op : Qubit => Unit is Adj, q : Qubit) : Unit { + Adjoint op(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyAdj(S, q); + } + "#, + ); +} + +#[test] +fn error_dynamic_callable() { + check_invariants( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + mutable op = H; + if true { set op = X; } + ApplyOp(op, q); + } + "#, + ); +} + +#[test] +fn branch_split_resolves_mutable_callable() { + check_invariants( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + mutable f = X; + if true { set f = H; } else { set f = S; } + ApplyOp(f, q); + } + "#, + ); +} + +#[test] +fn branch_split_resolves_conditional_binding() { + check_invariants( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + let f = if true { H } else { X }; + ApplyOp(f, q); + } + "#, + ); +} + +#[test] +fn error_returned_not_panicked() { + let (mut store, package_id) = compile_to_monomorphized_fir( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + mutable op = H; + for _ in 0..3 { set op = X; } + ApplyOp(op, q); + } + "#, + ); + let mut assigner = 
qsc_fir::assigner::Assigner::from_package(store.get(package_id)); + let errors = defunctionalize(&mut store, package_id, &mut assigner); + assert!( + !errors.is_empty(), + "expected errors to be returned, not a panic" + ); +} + +#[test] +fn error_multiple_dynamic_sites_collected() { + let (mut store, package_id) = compile_to_monomorphized_fir( + r#" + operation Apply1(op : Qubit => Unit, q : Qubit) : Unit { op(q); } + operation Apply2(op : Qubit => Unit, q : Qubit) : Unit { op(q); } + operation Main() : Unit { + use q = Qubit(); + mutable f = H; + for _ in 0..3 { set f = X; } + Apply1(f, q); + mutable g = X; + for _ in 0..3 { set g = H; } + Apply2(g, q); + } + "#, + ); + let mut assigner = qsc_fir::assigner::Assigner::from_package(store.get(package_id)); + let errors = defunctionalize(&mut store, package_id, &mut assigner); + assert_eq!( + errors.len(), + 2, + "expected both dynamic callable sites to be collected" + ); + for error in &errors { + assert!( + matches!(error, super::super::Error::DynamicCallable(_)), + "expected DynamicCallable error, got {error:?}" + ); + assert!( + !error.to_string().is_empty(), + "each error should have a display message" + ); + } +} + +#[test] +fn nested_hof_call_chain_passes_invariants() { + check_invariants( + r#" + operation ApplyInner(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation ApplyOuter(op : Qubit => Unit, q : Qubit) : Unit { + ApplyInner(op, q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOuter(H, q); + } + "#, + ); +} + +#[test] +fn hof_inside_for_loop_passes_invariants() { + check_invariants( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + for _ in 0..3 { + ApplyOp(H, q); + } + } + "#, + ); +} + +#[test] +fn function_callable_argument_defunctionalizes() { + check_invariants( + r#" + function ApplyFn(f : Int -> Int, x : Int) : Int { + f(x) + } + function Double(x : Int) : Int { x * 2 } + @EntryPoint() + 
operation Main() : Unit { + let _ = ApplyFn(Double, 5); + } + "#, + ); +} + +#[test] +fn explicit_functor_specializations_defunctionalize() { + check_invariants( + r#" + operation ApplyOp(op : Qubit => Unit is Adj + Ctl, q : Qubit) : Unit is Adj + Ctl { + body ... { op(q); } + adjoint ... { Adjoint op(q); } + controlled (ctls, ...) { Controlled op(ctls, q); } + controlled adjoint (ctls, ...) { Controlled Adjoint op(ctls, q); } + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(S, q); + } + "#, + ); +} + +#[test] +fn full_pipeline_preserves_post_all_invariants() { + check_pipeline( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(H, q); + ApplyOp(X, q); + let angle = 1.0; + ApplyOp(q1 => Rx(angle, q1), q); + } + "#, + ); +} + +#[test] +fn invariant_no_closures_remain() { + check_invariants( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(q1 => H(q1), q); + } + "#, + ); +} + +#[test] +fn invariant_no_arrow_params_remain() { + check_invariants( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(H, q); + ApplyOp(X, q); + } + "#, + ); +} + +#[test] +fn invariant_no_closures_after_full_defunc() { + check_invariants( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { op(q); } + operation Main() : Unit { + use q = Qubit(); + let angle = 1.0; + ApplyOp(q1 => Rx(angle, q1), q); + ApplyOp(H, q); + } + "#, + ); +} + +#[test] +fn five_branch_conditional_callable_resolves_successfully() { + check_invariants( + r#" + operation Apply(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + + operation Main() : Unit { + use q = Qubit(); + let n = 2; + mutable op = H; + if n == 0 { + op = X; + } elif n == 1 { + op = Y; + } elif n == 2 { + op = Z; + } elif n == 3 { + op = S; + } else { + op = T; + } + 
Apply(op, q); + } + "#, + ); +} + +#[test] +fn nine_branch_conditional_callable_degrades_to_dynamic() { + check_errors( + r#" + operation Apply(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + + operation Main() : Unit { + use q = Qubit(); + let n = 2; + mutable op = H; + if n == 0 { + op = X; + } elif n == 1 { + op = Y; + } elif n == 2 { + op = Z; + } elif n == 3 { + op = S; + } elif n == 4 { + op = T; + } elif n == 5 { + op = Rx(0.0, _); + } elif n == 6 { + op = Ry(0.0, _); + } elif n == 7 { + op = Rz(0.0, _); + } else { + op = SWAP(_, q); + } + Apply(op, q); + } + "#, + &expect!["callable argument could not be resolved statically"], + ); +} + +#[test] +fn controlled_functor_count_saturates_without_overflow() { + check_invariants( + r#" + operation Foo(q : Qubit) : Unit is Ctl { + body ... { H(q); } + controlled (cs, ...) { Controlled H(cs, q); } + } + operation ApplyCtl1(q : Qubit, c1 : Qubit) : Unit { + Controlled Foo([c1], q); + } + operation ApplyCtl2(q : Qubit, c1 : Qubit, c2 : Qubit) : Unit { + Controlled Foo([c1, c2], q); + } + operation ApplyCtl3(q : Qubit, c1 : Qubit, c2 : Qubit, c3 : Qubit) : Unit { + Controlled Foo([c1, c2, c3], q); + } + @EntryPoint() + operation Main() : Unit { + use (q, c1, c2, c3) = (Qubit(), Qubit(), Qubit(), Qubit()); + ApplyCtl1(q, c1); + ApplyCtl2(q, c1, c2); + ApplyCtl3(q, c1, c2, c3); + } + "#, + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/defunctionalize/tests/prepass.rs b/source/compiler/qsc_fir_transforms/src/defunctionalize/tests/prepass.rs new file mode 100644 index 0000000000..5512bd9ae8 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/defunctionalize/tests/prepass.rs @@ -0,0 +1,730 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Tests for the defunctionalization pre-pass rewrites. +//! +//! The pre-pass runs two key optimizations before collecting call sites: +//! 1. Promotes single-use immutable callable locals to direct item references +//! 2. 
Replaces identity closures `(args) => f(args)` with direct references to `f` + +use super::*; +use expect_test::expect; + +mod single_use_callable_local_promotion { + use super::*; + + /// Single-use callable local with simple item reference should be promoted. + #[test] + fn promote_simple_item_reference() { + check( + r#" + operation Main() : Unit { + use q = Qubit(); + let op = H; + op(q); + } + "#, + &expect![[r#" + Main: input_ty=Unit"#]], + ); + } + + /// Single-use callable local in HOF call should be promoted. + #[test] + fn promote_single_use_in_hof_call() { + check( + r#" + operation ApplyOp(f : Qubit => Unit, q : Qubit) : Unit { + f(q); + } + operation Main() : Unit { + use q = Qubit(); + let op = H; + ApplyOp(op, q); + } + "#, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); + } + + /// Multiple-use callable local still resolves through the later analysis. + #[test] + fn multiple_use_callable_local_resolves() { + check( + r#" + operation ApplyOp(f : Qubit => Unit, q : Qubit) : Unit { + f(q); + } + operation Main() : Unit { + use q = Qubit(); + let op = H; + ApplyOp(op, q); + ApplyOp(op, q); + } + "#, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); + } + + /// Callable local captured by an identity closure still resolves to its item. + #[test] + fn callable_local_captured_by_identity_closure_resolves() { + check( + r#" + operation ApplyOp(f : Qubit => Unit, q : Qubit) : Unit { + f(q); + } + operation Main() : Unit { + use q = Qubit(); + let op = H; + ApplyOp(q1 => op(q1), q); + } + "#, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); + } + + /// Mutable callable local with a static value still resolves through analysis. 
+ #[test] + fn mutable_callable_local_resolves() { + check( + r#" + operation ApplyOp(f : Qubit => Unit, q : Qubit) : Unit { + f(q); + } + operation Main() : Unit { + use q = Qubit(); + mutable op = H; + ApplyOp(op, q); + } + "#, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); + } + + /// Callable local with identity-closure initializer should be simplified. + #[test] + fn callable_local_with_identity_closure_initializer_resolves() { + check( + r#" + operation ApplyOp(f : Qubit => Unit, q : Qubit) : Unit { + f(q); + } + operation Main() : Unit { + use q = Qubit(); + let op = q1 => H(q1); + ApplyOp(op, q); + } + "#, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); + } + + /// Callable local with a partial-application initializer resolves through closure lifting. + #[test] + fn no_promote_partial_application_initializer_resolves() { + check( + r#" + operation ApplyOp(f : Qubit => Unit, q : Qubit) : Unit { + f(q); + } + operation Parametrized(angle : Double, q : Qubit) : Unit { + Rz(angle, q); + } + operation Main() : Unit { + use q = Qubit(); + let angle = 0.5; + let op = Parametrized(angle, _); + ApplyOp(op, q); + } + "#, + &expect![[r#" + : input_ty=(Double, Qubit) + ApplyOp{closure}: input_ty=(Qubit, Double) + Main: input_ty=Unit + Parametrized: input_ty=(Double, Qubit)"#]], + ); + } + + /// Single-use callable local in nested scope should be promoted. + #[test] + fn promote_in_nested_scope() { + check( + r#" + operation ApplyOp(f : Qubit => Unit, q : Qubit) : Unit { + f(q); + } + operation Main() : Unit { + use q = Qubit(); + if true { + let op = H; + ApplyOp(op, q); + } + } + "#, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); + } + + /// Unused callable local (zero uses) is irrelevant but shouldn't cause issues. 
+ #[test] + fn no_promote_zero_uses() { + check( + r#" + operation Main() : Unit { + use q = Qubit(); + let op = H; + () + } + "#, + &expect![[r#" + Main: input_ty=Unit"#]], + ); + } + + /// Single-use callable local with non-callable type should NOT be promoted. + #[test] + fn no_promote_non_callable_type() { + check( + r#" + operation Main() : Unit { + use q = Qubit(); + let x = 42; + let y = x; + } + "#, + &expect![[r#" + Main: input_ty=Unit"#]], + ); + } +} + +mod identity_closure_peephole_optimization { + use super::*; + + /// Basic identity closure `(q) => H(q)` should be replaced with `H`. + #[test] + fn identity_closure_basic() { + check( + r#" + operation ApplyOp(f : Qubit => Unit, q : Qubit) : Unit { + f(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(q1 => H(q1), q); + } + "#, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); + } + + /// Identity closure with multiple parameters should be replaced. + #[test] + fn identity_closure_multiple_params() { + check( + r#" + operation ApplyTwo(f : (Qubit, Qubit) => Unit, q1 : Qubit, q2 : Qubit) : Unit { + f(q1, q2); + } + operation Main() : Unit { + use q1 = Qubit(); + use q2 = Qubit(); + ApplyTwo((control, target) => CNOT(control, target), q1, q2); + } + "#, + &expect![[r#" + ApplyTwo{CNOT}: input_ty=(Qubit, Qubit) + Main: input_ty=Unit"#]], + ); + } + + /// Identity closure with captured variable should be replaced. + #[test] + fn identity_closure_with_capture() { + check( + r#" + operation ApplyOp(f : Qubit => Unit, q : Qubit) : Unit { + f(q); + } + operation Main() : Unit { + use q = Qubit(); + let myH = H; + ApplyOp(q1 => myH(q1), q); + } + "#, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); + } + + /// Adjoint identity closure `(q) => Adjoint H(q)` should be optimized. 
+ #[test] + fn identity_closure_adjoint() { + check( + r#" + operation ApplyOp(f : Qubit => Unit is Adj, q : Qubit) : Unit { + f(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(q1 => Adjoint H(q1), q); + } + "#, + &expect![[r#" + ApplyOp{Adj H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); + } + + /// Controlled identity closure `(q) => Controlled X([], q)` should be optimized. + #[test] + fn identity_closure_controlled() { + check( + r#" + operation ApplyOp(f : (Qubit[], Qubit) => Unit is Ctl, q : Qubit) : Unit { + f([], q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp((ctrls, tgt) => Controlled X(ctrls, tgt), q); + } + "#, + &expect![[r#" + ApplyOp{Ctl X}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); + } + + /// Non-identity closure should NOT be optimized (argument reordering). + #[test] + fn no_optimize_reordered_args() { + check( + r#" + operation ApplyTwo(f : (Qubit, Qubit) => Unit, q1 : Qubit, q2 : Qubit) : Unit { + f(q1, q2); + } + operation Main() : Unit { + use q1 = Qubit(); + use q2 = Qubit(); + ApplyTwo((a, b) => H(b), q1, q2); + } + "#, + &expect![[r#" + : input_ty=((Qubit, Qubit),) + ApplyTwo{closure}: input_ty=(Qubit, Qubit) + Main: input_ty=Unit"#]], + ); + } + + /// Non-identity closure with capture in args should NOT be optimized. + #[test] + fn no_optimize_capture_in_args() { + check( + r#" + operation ApplyOp(f : Qubit => Unit, q : Qubit) : Unit { + f(q); + } + operation Main() : Unit { + use q = Qubit(); + let myQ = q; + ApplyOp(q1 => H(myQ), q); + } + "#, + &expect![[r#" + : input_ty=(Qubit, Qubit) + ApplyOp{closure}: input_ty=(Qubit, Qubit) + Main: input_ty=Unit"#]], + ); + } + + /// Closure that does not forward its parameter should NOT be optimized. 
+ #[test] + fn no_optimize_non_forwarded_param() { + check( + r#" + operation ApplyOp(f : (Unit => Unit), _ : Unit) : Unit { + f(()); + } + operation Main() : Unit { + use other = Qubit(); + ApplyOp(u => H(other), ()); + Reset(other); + } + "#, + &expect![[r#" + : input_ty=(Qubit, Unit) + ApplyOp{closure}: input_ty=(Unit, Qubit) + Main: input_ty=Unit"#]], + ); + } + + /// Closure with multiple statements should NOT be optimized (not identity). + #[test] + fn no_optimize_multiple_statements() { + check( + r#" + operation ApplyOp(f : Qubit => Unit, q : Qubit) : Unit { + f(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(q1 => { H(q1); X(q1) }, q); + } + "#, + &expect![[r#" + : input_ty=(Qubit,) + ApplyOp{closure}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); + } + + /// Closure body that's not a call should NOT be optimized. + #[test] + fn no_optimize_non_call_body() { + check( + r#" + operation ApplyOp(f : Qubit => Int, q : Qubit) : Int { + f(q) + } + operation Main() : Unit { + use q = Qubit(); + let result = ApplyOp(q1 => 42, q); + } + "#, + &expect![[r#" + : input_ty=(Qubit,) + ApplyOp{closure}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); + } +} + +mod combined_promotion_and_peephole_optimizations { + use super::*; + + /// Single-use local with identity closure should both be optimized. + #[test] + fn combined_promotion_and_identity_closure() { + check( + r#" + operation ApplyOp(f : Qubit => Unit, q : Qubit) : Unit { + f(q); + } + operation Main() : Unit { + use q = Qubit(); + let op = q1 => H(q1); + ApplyOp(op, q); + } + "#, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); + } + + /// Multiple single-use locals with identity closures. 
+ #[test] + fn multiple_promoted_identity_closures() { + check( + r#" + operation ApplyOp(f : Qubit => Unit, q : Qubit) : Unit { + f(q); + } + operation Main() : Unit { + use q = Qubit(); + let op1 = q1 => H(q1); + let op2 = q1 => X(q1); + ApplyOp(op1, q); + ApplyOp(op2, q); + } + "#, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + ApplyOp{X}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); + } + + /// Promoted local used in identity closure. + #[test] + fn promoted_local_in_identity_closure() { + check( + r#" + operation ApplyOp(f : Qubit => Unit, q : Qubit) : Unit { + f(q); + } + operation Main() : Unit { + use q = Qubit(); + let myH = H; + ApplyOp(q1 => myH(q1), q); + } + "#, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); + } +} + +mod edge_cases_and_complex_scenarios { + use super::*; + + /// Identity closure with adjoint and captured variable. + #[test] + fn identity_closure_adjoint_captured() { + check( + r#" + operation ApplyOp(f : Qubit => Unit is Adj, q : Qubit) : Unit { + f(q); + } + operation Main() : Unit { + use q = Qubit(); + let op = H; + ApplyOp(q1 => Adjoint op(q1), q); + } + "#, + &expect![[r#" + ApplyOp{Adj H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); + } + + /// Complex HOF with mixed promoted and identity closures. + #[test] + fn complex_hof_mixed_optimizations() { + check( + r#" + operation ApplyTwo(f : Qubit => Unit, g : Qubit => Unit, q : Qubit) : Unit { + f(q); + g(q); + } + operation Main() : Unit { + use q = Qubit(); + let op = H; + ApplyTwo(op, q1 => X(q1), q); + } + "#, + &expect![[r#" + ApplyTwo{H}{X}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); + } + + /// Identity closure with parameter passed to a nested operation. 
+ #[test] + fn identity_closure_param_to_nested_op() { + check( + r#" + operation Inner(f : Qubit => Unit, q : Qubit) : Unit { + f(q); + } + operation Outer(g : Qubit => Unit, q : Qubit) : Unit { + Inner(g, q); + } + operation Main() : Unit { + use q = Qubit(); + Outer(q1 => H(q1), q); + } + "#, + &expect![[r#" + Inner{H}: input_ty=Qubit + Main: input_ty=Unit + Outer{H}: input_ty=Qubit"#]], + ); + } + + /// Single-use callable local assigned from another single-use callable local (chain). + #[test] + fn promoted_local_chain() { + check( + r#" + operation ApplyOp(f : Qubit => Unit, q : Qubit) : Unit { + f(q); + } + operation Main() : Unit { + use q = Qubit(); + let op1 = H; + let op2 = op1; + ApplyOp(op2, q); + } + "#, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); + } + + /// Identity closure capturing a single-use promoted local. + #[test] + fn identity_closure_captures_promoted_local() { + check( + r#" + operation ApplyOp(f : Qubit => Unit, q : Qubit) : Unit { + f(q); + } + operation Main() : Unit { + use q = Qubit(); + let myH = H; + let op = q1 => myH(q1); + ApplyOp(op, q); + } + "#, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); + } + + /// Intrinsic callable should not cause issues in identity closure detection. + #[test] + fn identity_closure_with_intrinsic() { + check( + r#" + operation ApplyOp(f : Qubit => Unit, q : Qubit) : Unit { + f(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(q1 => H(q1), q); + } + "#, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); + } + + /// Callable local with discard pattern should NOT be promoted. + #[test] + fn no_promote_discard_pattern() { + check( + r#" + operation Main() : Unit { + use q = Qubit(); + let _ = H; + } + "#, + &expect![[r#" + Main: input_ty=Unit"#]], + ); + } + + /// Callable local with tuple destructuring still resolves through analysis. 
+ #[test] + fn tuple_destructured_callable_local_resolves() { + check( + r#" + operation ApplyOp(f : Qubit => Unit, q : Qubit) : Unit { + f(q); + } + operation Main() : Unit { + use q = Qubit(); + let (op, _) = (H, X); + ApplyOp(op, q); + } + "#, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); + } +} + +mod parameter_extraction_and_validation_helpers { + use super::*; + + /// Identity closure with tuple of single parameters should work. + #[test] + fn identity_closure_tuple_params() { + check( + r#" + operation ApplyTwo(f : (Int, Qubit) => Unit, q : Qubit, n : Int) : Unit { + f(n, q); + } + operation UseIntQubit(i : Int, q : Qubit) : Unit { + if i == 42 { + H(q); + } + } + operation Main() : Unit { + use q = Qubit(); + let n = 42; + ApplyTwo((i, q1) => UseIntQubit(i, q1), q, n); + } + "#, + &expect![[r#" + ApplyTwo{UseIntQubit}: input_ty=(Qubit, Int) + Main: input_ty=Unit + UseIntQubit: input_ty=(Int, Qubit)"#]], + ); + } +} + +mod nested_function_scopes { + use super::*; + + /// Single-use callable local in nested function scope. + #[test] + fn promote_in_nested_function() { + check( + r#" + operation ApplyOp(f : Qubit => Unit, q : Qubit) : Unit { + f(q); + } + operation Outer() : Unit { + use q = Qubit(); + if true { + let op = H; + ApplyOp(op, q); + } + } + operation Main() : Unit { + Outer(); + } + "#, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + Main: input_ty=Unit + Outer: input_ty=Unit"#]], + ); + } + + /// Identity closure in nested function scope. 
+ #[test] + fn identity_closure_nested_function() { + check( + r#" + operation ApplyOp(f : Qubit => Unit, q : Qubit) : Unit { + f(q); + } + operation Outer() : Unit { + use q = Qubit(); + ApplyOp(q1 => H(q1), q); + } + operation Main() : Unit { + Outer(); + } + "#, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + Main: input_ty=Unit + Outer: input_ty=Unit"#]], + ); + } +} diff --git a/source/compiler/qsc_fir_transforms/src/defunctionalize/tests/specialization.rs b/source/compiler/qsc_fir_transforms/src/defunctionalize/tests/specialization.rs new file mode 100644 index 0000000000..2af64da152 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/defunctionalize/tests/specialization.rs @@ -0,0 +1,1135 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use super::*; +use expect_test::expect; + +#[test] +fn specialize_single_global_callable() { + check( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(H, q); + } + "#, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); +} + +#[test] +fn specialize_two_different_callables() { + check( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(H, q); + ApplyOp(X, q); + } + "#, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + ApplyOp{X}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); +} + +#[test] +fn specialize_same_callable_reuse() { + check( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(H, q); + ApplyOp(H, q); + } + "#, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); +} + +#[test] +fn specialize_no_hof_unchanged() { + check( + r#" + operation Foo(q : Qubit) : Unit { + H(q); + } + operation Main() : Unit { + use q = Qubit(); + Foo(q); + } + "#, + &expect![[r#" + Foo: input_ty=Qubit + 
Main: input_ty=Unit"#]], + ); +} + +#[test] +fn specialize_closure_no_captures() { + check( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(q1 => H(q1), q); + } + "#, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); +} + +#[test] +fn specialize_closure_with_captures() { + check( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + let angle = 1.0; + ApplyOp(q1 => Rx(angle, q1), q); + } + "#, + &expect![[r#" + : input_ty=(Double, Qubit) + ApplyOp{closure}: input_ty=(Qubit, Double) + Main: input_ty=Unit"#]], + ); +} + +#[test] +fn specialize_closure_capture_types_preserved() { + check( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + let n = 3; + ApplyOp(q1 => { for _ in 0..n { H(q1); } }, q); + } + "#, + &expect![[r#" + : input_ty=(Int, Qubit) + ApplyOp{closure}: input_ty=(Qubit, Int) + Main: input_ty=Unit"#]], + ); +} + +#[test] +fn specialize_creation_site_adjoint() { + check( + r#" + operation ApplyOp(op : Qubit => Unit is Adj, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(Adjoint S, q); + } + "#, + &expect![[r#" + ApplyOp{Adj S}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); +} + +#[test] +fn specialize_body_side_adjoint() { + check( + r#" + operation ApplyAdj(op : Qubit => Unit is Adj, q : Qubit) : Unit { + Adjoint op(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyAdj(S, q); + } + "#, + &expect![[r#" + ApplyAdj{S}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); +} + +#[test] +fn specialize_double_adjoint_cancels() { + check( + r#" + operation ApplyAdj(op : Qubit => Unit is Adj, q : Qubit) : Unit { + Adjoint op(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyAdj(Adjoint S, q); + } + "#, + &expect![[r#" + ApplyAdj{Adj S}: 
input_ty=Qubit + Main: input_ty=Unit"#]], + ); +} + +#[test] +fn specialize_body_side_controlled() { + check( + r#" + operation ApplyCtl(op : Qubit => Unit is Ctl, ctl : Qubit, q : Qubit) : Unit { + Controlled op([ctl], q); + } + operation Main() : Unit { + use (ctl, q) = (Qubit(), Qubit()); + ApplyCtl(X, ctl, q); + } + "#, + &expect![[r#" + ApplyCtl{X}: input_ty=(Qubit, Qubit) + Main: input_ty=Unit"#]], + ); +} + +#[test] +fn specialize_body_controlled_adjoint_nested() { + check( + r#" + operation ApplyCtlAdj(op : Qubit => Unit is Adj + Ctl, ctl : Qubit, q : Qubit) : Unit { + Controlled Adjoint op([ctl], q); + } + operation Main() : Unit { + use (ctl, q) = (Qubit(), Qubit()); + ApplyCtlAdj(S, ctl, q); + } + "#, + &expect![[r#" + ApplyCtlAdj{S}: input_ty=(Qubit, Qubit) + Main: input_ty=Unit"#]], + ); +} + +#[test] +fn specialize_creation_adjoint_body_controlled() { + check( + r#" + operation ApplyCtl(op : Qubit => Unit is Adj + Ctl, ctl : Qubit, q : Qubit) : Unit { + Controlled op([ctl], q); + } + operation Main() : Unit { + use (ctl, q) = (Qubit(), Qubit()); + ApplyCtl(Adjoint S, ctl, q); + } + "#, + &expect![[r#" + ApplyCtl{Adj S}: input_ty=(Qubit, Qubit) + Main: input_ty=Unit"#]], + ); +} + +#[test] +fn specialize_hof_with_adj_autogen() { + check( + r#" + operation ApplyOp(op : Qubit => Unit is Adj, q : Qubit) : Unit is Adj { + body ... { op(q); } + adjoint auto; + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(S, q); + Adjoint ApplyOp(S, q); + } + "#, + &expect![[r#" + ApplyOp{S}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); +} + +#[test] +fn specialize_hof_with_ctl_autogen() { + check( + r#" + operation ApplyOp(op : Qubit => Unit is Ctl, q : Qubit) : Unit is Ctl { + body ... 
{ op(q); } + controlled auto; + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(X, q); + } + "#, + &expect![[r#" + ApplyOp{X}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); +} + +#[test] +fn specialize_hof_with_adj_ctl_autogen() { + check( + r#" + operation ApplyOp(op : Qubit => Unit is Adj + Ctl, q : Qubit) : Unit is Adj + Ctl { + body ... { op(q); } + adjoint auto; + controlled auto; + controlled adjoint auto; + } + operation Main() : Unit { + use (ctl, q) = (Qubit(), Qubit()); + ApplyOp(S, q); + } + "#, + &expect![[r#" + ApplyOp{S}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); +} + +#[test] +fn specialize_single_assignment_local() { + check( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + let myH = H; + ApplyOp(myH, q); + } + "#, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); +} + +#[test] +fn defunctionalized_call_site_drops_callable_argument() { + let source = r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(H, q); + } + "#; + check( + source, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); + assert_eq!( + call_arg_tuple_lengths_after_defunc(source, "ApplyOp{H}"), + vec![1], + "defunctionalized ApplyOp call should pass only the qubit argument" + ); +} + +#[test] +fn rewrite_closure_capture_args_inserted() { + let source = r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + let angle = 1.0; + ApplyOp(q1 => Rx(angle, q1), q); + } + "#; + check( + source, + &expect![[r#" + : input_ty=(Double, Qubit) + ApplyOp{closure}: input_ty=(Qubit, Double) + Main: input_ty=Unit"#]], + ); + assert_eq!( + call_arg_tuple_lengths_after_defunc(source, "ApplyOp{closure}"), + vec![2], + "rewritten closure call should pass the qubit and captured angle" + ); +} + +#[test] +fn 
multiple_callable_parameters_specialize_independently() { + check( + r#" + operation ApplyTwo(f : Qubit => Unit, g : Qubit => Unit, q : Qubit) : Unit { + f(q); + g(q); + } + operation Main() : Unit { + use q = Qubit(); + ApplyTwo(H, X, q); + } + "#, + &expect![[r#" + ApplyTwo{H}{X}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); +} + +#[test] +fn capture_local_ids_are_reasonable() { + let (mut fir_store, fir_pkg_id) = compile_to_monomorphized_fir( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + let angle = 1.0; + ApplyOp(q1 => Rx(angle, q1), q); + } + "#, + ); + let mut assigner = qsc_fir::assigner::Assigner::from_package(fir_store.get(fir_pkg_id)); + let errors = defunctionalize(&mut fir_store, fir_pkg_id, &mut assigner); + assert_no_defunctionalization_errors("defunctionalization", &errors); + let package = fir_store.get(fir_pkg_id); + + for (_, pat) in &package.pats { + if let fir::PatKind::Bind(ident) = &pat.kind { + let id: u32 = ident.id.into(); + assert!( + id < 10_000, + "LocalVarId {id} is unreasonably large -- capture IDs should be sequential, not u32::MAX-based" + ); + } + } +} + +#[test] +fn pipeline_with_captures_no_sroa_panic() { + use crate::test_utils::{PipelineStage, compile_and_run_pipeline_to}; + + let (_store, _pkg_id) = compile_and_run_pipeline_to( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + let pair = (1.0, 2.0); + let (a, b) = pair; + ApplyOp(q1 => Rx(a + b, q1), q); + } + "#, + PipelineStage::Full, + ); +} + +#[test] +fn multiple_captures_sequential_ids() { + let (mut fir_store, fir_pkg_id) = compile_to_monomorphized_fir( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + let a = 1.0; + let b = 2.0; + let c = 3.0; + ApplyOp(q1 => { Rx(a, q1); Ry(b, q1); Rz(c, q1); }, q); + } + "#, + ); + let mut assigner 
= qsc_fir::assigner::Assigner::from_package(fir_store.get(fir_pkg_id)); + let errors = defunctionalize(&mut fir_store, fir_pkg_id, &mut assigner); + assert_no_defunctionalization_errors("defunctionalization", &errors); + let package = fir_store.get(fir_pkg_id); + + let mut capture_ids: Vec = Vec::new(); + for (_, pat) in &package.pats { + if let fir::PatKind::Bind(ident) = &pat.kind + && ident.name.starts_with("__capture_") + { + let id: u32 = ident.id.into(); + capture_ids.push(id); + } + } + + assert!( + capture_ids.len() >= 3, + "expected at least 3 capture bindings, found {}", + capture_ids.len() + ); + + for &id in &capture_ids { + assert!(id < 10_000, "capture LocalVarId {id} is unreasonably large"); + } + + capture_ids.sort_unstable(); + for window in capture_ids.windows(2) { + assert_eq!( + window[1] - window[0], + 1, + "capture IDs should be sequential, got {} and {}", + window[0], + window[1] + ); + } +} + +#[test] +fn specialize_closure_capturing_immutable_variable() { + check( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { op(q); } + operation Main() : Unit { + use q = Qubit(); + let angle = 1.0; + ApplyOp(q1 => Rx(angle, q1), q); + } + "#, + &expect![[r#" + : input_ty=(Double, Qubit) + ApplyOp{closure}: input_ty=(Qubit, Double) + Main: input_ty=Unit"#]], + ); +} + +#[test] +fn specialize_closure_in_while_loop_body() { + check( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { op(q); } + operation Main() : Unit { + use q = Qubit(); + mutable n = 3; + while n > 0 { + ApplyOp(q1 => H(q1), q); + n -= 1; + } + } + "#, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); +} + +#[test] +fn specialize_multiple_closures_same_signature() { + check( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { op(q); } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(q1 => H(q1), q); + ApplyOp(q1 => X(q1), q); + } + "#, + &expect![[r#" + ApplyOp{H}: input_ty=Qubit + ApplyOp{X}: input_ty=Qubit 
+ Main: input_ty=Unit"#]], + ); +} + +#[test] +fn branch_split_two_callees() { + check_invariants( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + let f = if true { H } else { X }; + ApplyOp(f, q); + } + "#, + ); +} + +#[test] +fn branch_split_three_callees() { + check_invariants( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + let f = if true { H } elif false { X } else { S }; + ApplyOp(f, q); + } + "#, + ); +} + +#[test] +fn branch_split_mutable_conditional() { + check_invariants( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + mutable op = H; + if true { set op = X; } + ApplyOp(op, q); + } + "#, + ); +} + +#[test] +fn branch_split_nested_callable_in_tuple() { + check_invariants( + r#" + operation Wrapper(pair : (Qubit => Unit, Int), q : Qubit) : Unit { + let (op, _) = pair; + op(q); + } + operation Main() : Unit { + use q = Qubit(); + let f = if true { H } else { X }; + Wrapper((f, 42), q); + } + "#, + ); +} + +#[test] +fn branch_split_nested_callable_in_tuple_args_consistency() { + let (mut fir_store, fir_pkg_id) = compile_to_monomorphized_fir( + r#" + operation Wrapper(pair : (Qubit => Unit, Int), q : Qubit) : Unit { + let (op, _) = pair; + op(q); + } + operation Main() : Unit { + use q = Qubit(); + let f = if true { H } else { X }; + Wrapper((f, 42), q); + } + "#, + ); + let mut assigner = qsc_fir::assigner::Assigner::from_package(fir_store.get(fir_pkg_id)); + let errors = defunctionalize(&mut fir_store, fir_pkg_id, &mut assigner); + assert_no_defunctionalization_errors("defunctionalization", &errors); + let package = fir_store.get(fir_pkg_id); + + let mut mismatches = Vec::new(); + for (expr_id, expr) in &package.exprs { + if let fir::ExprKind::Call(_callee_id, args_id) = &expr.kind { + let args_expr = 
package.get_expr(*args_id); + if let fir::ExprKind::Tuple(elements) = &args_expr.kind + && let qsc_fir::ty::Ty::Tuple(type_elems) = &args_expr.ty + { + if elements.len() != type_elems.len() { + mismatches.push(format!( + "Call expr {expr_id}: args tuple has {} elements but type has {} elements", + elements.len(), + type_elems.len() + )); + } + for (i, (&elem_id, ty_elem)) in elements.iter().zip(type_elems.iter()).enumerate() { + let elem_expr = package.get_expr(elem_id); + let elem_is_tuple = matches!(elem_expr.kind, fir::ExprKind::Tuple(_)); + let ty_is_tuple = matches!(ty_elem, qsc_fir::ty::Ty::Tuple(_)); + if elem_is_tuple != ty_is_tuple { + mismatches.push(format!( + "Call expr {expr_id}: args[{i}] is_tuple={elem_is_tuple} but type is_tuple={ty_is_tuple} (elem_ty={}, type_elem={ty_elem})", + elem_expr.ty, + )); + } + } + } + } + } + assert!( + mismatches.is_empty(), + "Type/value mismatches in branch-split args:\n{}", + mismatches.join("\n") + ); +} + +#[test] +fn branch_split_nested_callable_full_pipeline() { + use crate::test_utils::{PipelineStage, compile_and_run_pipeline_to}; + + let (_store, _pkg_id) = compile_and_run_pipeline_to( + r#" + operation Wrapper(pair : (Qubit => Unit, Int), q : Qubit) : Unit { + let (op, _) = pair; + op(q); + } + operation Main() : Unit { + use q = Qubit(); + let f = if true { H } else { X }; + Wrapper((f, 42), q); + } + "#, + PipelineStage::Full, + ); +} + +#[test] +fn specialize_nested_callable_first_element() { + check( + r#" + operation Wrapper(pair : (Qubit => Unit, Int), q : Qubit) : Unit { + let (op, _) = pair; + op(q); + } + operation Main() : Unit { + use q = Qubit(); + Wrapper((H, 42), q); + } + "#, + &expect![[r#" + Main: input_ty=Unit + Wrapper{H}: input_ty=(Int, Qubit)"#]], + ); +} + +#[test] +fn specialize_nested_callable_second_element() { + check( + r#" + operation Wrapper(pair : (Int, Qubit => Unit), q : Qubit) : Unit { + let (_, op) = pair; + op(q); + } + operation Main() : Unit { + use q = Qubit(); + 
Wrapper((42, H), q); + } + "#, + &expect![[r#" + Main: input_ty=Unit + Wrapper{H}: input_ty=(Int, Qubit)"#]], + ); +} + +#[test] +fn specialize_nested_callable_both_fields_used() { + check( + r#" + operation Wrapper(pair : (Qubit => Unit, Int), q : Qubit) : Unit { + let (op, n) = pair; + op(q); + let _ = n; + } + operation Main() : Unit { + use q = Qubit(); + Wrapper((H, 42), q); + } + "#, + &expect![[r#" + Main: input_ty=Unit + Wrapper{H}: input_ty=(Int, Qubit)"#]], + ); +} + +#[test] +fn specialize_nested_callable_transitive_alias() { + check( + r#" + operation Wrapper(pair : (Qubit => Unit, Int), q : Qubit) : Unit { + let (op, _) = pair; + let f = op; + f(q); + } + operation Main() : Unit { + use q = Qubit(); + Wrapper((H, 42), q); + } + "#, + &expect![[r#" + Main: input_ty=Unit + Wrapper{H}: input_ty=(Int, Qubit)"#]], + ); +} + +#[test] +fn specialize_nested_callable_invariants() { + check_invariants( + r#" + operation Wrapper(pair : (Qubit => Unit, Int), q : Qubit) : Unit { + let (op, _) = pair; + op(q); + } + operation Main() : Unit { + use q = Qubit(); + Wrapper((H, 42), q); + } + "#, + ); +} + +#[test] +fn specialize_nested_callable_full_pipeline() { + use crate::test_utils::{PipelineStage, compile_and_run_pipeline_to}; + + let (_store, _pkg_id) = compile_and_run_pipeline_to( + r#" + operation Wrapper(pair : (Qubit => Unit, Int), q : Qubit) : Unit { + let (op, _) = pair; + op(q); + } + operation Main() : Unit { + use q = Qubit(); + Wrapper((H, 42), q); + } + "#, + PipelineStage::Full, + ); +} + +#[test] +fn branch_split_nested_callable_adj_ctl_args_consistency() { + let (mut fir_store, fir_pkg_id) = compile_to_monomorphized_fir( + r#" + operation Op1(q : Qubit) : Unit is Adj + Ctl { H(q); } + operation Op2(q : Qubit) : Unit is Adj + Ctl { X(q); } + operation Wrapper(pair : (Qubit => Unit is Adj + Ctl, Int), q : Qubit) : Unit { + let (op, _) = pair; + op(q); + } + operation Main() : Unit { + use q = Qubit(); + let b = true; + let f = if b { Op1 } else { Op2 
}; + Wrapper((f, 42), q); + } + "#, + ); + let mut assigner = qsc_fir::assigner::Assigner::from_package(fir_store.get(fir_pkg_id)); + let errors = defunctionalize(&mut fir_store, fir_pkg_id, &mut assigner); + assert_no_defunctionalization_errors("defunctionalization", &errors); + let package = fir_store.get(fir_pkg_id); + + let mut mismatches = Vec::new(); + for (expr_id, expr) in &package.exprs { + if let fir::ExprKind::Call(_callee_id, args_id) = &expr.kind { + let args_expr = package.get_expr(*args_id); + if let fir::ExprKind::Tuple(elements) = &args_expr.kind + && let qsc_fir::ty::Ty::Tuple(type_elems) = &args_expr.ty + && elements.len() != type_elems.len() + { + mismatches.push(format!( + "Call expr {expr_id}: args tuple has {} elements but type has {} elements", + elements.len(), + type_elems.len() + )); + } + } + } + assert!( + mismatches.is_empty(), + "Type/value mismatches in branch-split args:\n{}", + mismatches.join("\n") + ); +} + +#[test] +fn closure_with_multiple_captures_threads_all_captures() { + check( + r#" + operation Apply(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + + operation Main() : Unit { + use q = Qubit(); + let angle1 = 1.0; + let angle2 = 2.0; + let myOp = (q) => { Rx(angle1, q); Ry(angle2, q); }; + Apply(myOp, q); + } + "#, + &expect![[r#" + : input_ty=(Double, Double, Qubit) + Apply{closure}: input_ty=(Qubit, Double, Double) + Main: input_ty=Unit"#]], + ); +} + +#[test] +fn single_param_tuple_containing_arrow_specializes_end_to_end() { + check( + r#" + operation Apply(pair : (Qubit => Unit, Qubit)) : Unit { + let (op, q) = pair; + op(q); + } + @EntryPoint() + operation Main() : Unit { + use q = Qubit(); + Apply((H, q)); + } + "#, + &expect![[r#" + Apply{H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); +} + +#[test] +fn single_param_tuple_second_element_specializes_end_to_end() { + check( + r#" + operation Wrapper(pair : (Int, Qubit => Unit)) : Unit { + let (_, op) = pair; + use q = Qubit(); + op(q); + } + operation Main() 
: Unit { + Wrapper((42, H)); + } + "#, + &expect![[r#" + Main: input_ty=Unit + Wrapper{H}: input_ty=Int"#]], + ); +} + +#[test] +fn single_param_recursive_tuple_callable_specializes_end_to_end() { + check( + r#" + operation Wrapper(bundle : (((Qubit => Unit, Int), Double), Qubit)) : Unit { + let (((op, n), angle), q) = bundle; + let _ = n; + let _ = angle; + op(q); + } + operation Main() : Unit { + use q = Qubit(); + Wrapper((((H, 42), 1.0), q)); + } + "#, + &expect![[r#" + Main: input_ty=Unit + Wrapper{H}: input_ty=((Int, Double), Qubit)"#]], + ); +} + +#[test] +fn single_param_recursive_tuple_callable_closure_capture_invariants() { + check_invariants( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Wrapper(bundle : (((Qubit => Unit, Int), Double), Qubit)) : Unit { + let (((op, n), angle), q) = bundle; + ApplyOp( + q1 => { + if n == 0 { + Rx(angle, q1); + } + op(q1); + }, + q + ); + } + operation Main() : Unit { + use q = Qubit(); + Wrapper((((H, 0), 1.0), q)); + } + "#, + ); +} + +#[test] +fn three_branch_conditional_callable_generates_branch_split() { + let source = r#" + operation Apply(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + + operation Main() : Unit { + use q = Qubit(); + let n = 2; + mutable op = H; + if n == 0 { + op = X; + } elif n == 1 { + op = Y; + } else { + op = Z; + } + Apply(op, q); + } + "#; + check_errors(source, &expect!["(no error)"]); + let targets = callable_call_targets_after_defunc(source, "Main"); + assert!( + targets.contains(&"Apply{X}".to_string()) + && targets.contains(&"Apply{Y}".to_string()) + && targets.contains(&"Apply{Z}".to_string()), + "branch split should call X, Y, and Z specializations, got {targets:?}" + ); +} + +#[test] +fn identity_closure_peephole_replaces_wrapper() { + check( + r#" + operation Apply(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + + operation Main() : Unit { + use q = Qubit(); + let wrapper = q => H(q); + Apply(wrapper, q); + } + "#, + 
&expect![[r#" + Apply{H}: input_ty=Qubit + Main: input_ty=Unit"#]], + ); +} + +#[test] +fn excessive_specializations_warning_emitted() { + // A HOF called with > 10 different concrete closures triggers the + // ExcessiveSpecializations warning. Each distinct Rx(angle, _) partial + // application with a different angle creates a distinct closure, and + // all closures map to the same functorless Apply variant. + check_errors( + r#" + operation Apply(op : Qubit => Unit, q : Qubit) : Unit { op(q); } + operation Main() : Unit { + use q = Qubit(); + Apply(q1 => Rx(1.0, q1), q); + Apply(q1 => Rx(2.0, q1), q); + Apply(q1 => Rx(3.0, q1), q); + Apply(q1 => Rx(4.0, q1), q); + Apply(q1 => Rx(5.0, q1), q); + Apply(q1 => Rx(6.0, q1), q); + Apply(q1 => Rx(7.0, q1), q); + Apply(q1 => Rx(8.0, q1), q); + Apply(q1 => Rx(9.0, q1), q); + Apply(q1 => Rx(10.0, q1), q); + Apply(q1 => Rx(11.0, q1), q); + } + "#, + &expect![[r#" + higher-order function `Apply` generated 11 specializations, exceeding the warning threshold"#]], + ); +} + +#[test] +fn below_threshold_no_excessive_specializations_warning() { + // A HOF with exactly 10 specializations should NOT trigger the warning. + check_errors( + r#" + operation Apply(op : Qubit => Unit, q : Qubit) : Unit { op(q); } + operation Main() : Unit { + use q = Qubit(); + Apply(H, q); + Apply(X, q); + Apply(Y, q); + Apply(Z, q); + Apply(S, q); + Apply(T, q); + Apply(I, q); + Apply(q1 => Rx(1.0, q1), q); + Apply(q1 => Rx(2.0, q1), q); + Apply(q1 => Rx(3.0, q1), q); + } + "#, + &expect!["(no error)"], + ); +} + +#[test] +fn excessive_specializations_warning_does_not_block_compilation() { + // A program that triggers ExcessiveSpecializations should still compile + // successfully — the warning is non-fatal. We verify by running the + // full defunctionalization and checking PostDefunc invariants hold. 
+ let (mut fir_store, fir_pkg_id) = compile_to_monomorphized_fir( + r#" + operation Apply(op : Qubit => Unit, q : Qubit) : Unit { op(q); } + operation Main() : Unit { + use q = Qubit(); + Apply(q1 => Rx(1.0, q1), q); + Apply(q1 => Rx(2.0, q1), q); + Apply(q1 => Rx(3.0, q1), q); + Apply(q1 => Rx(4.0, q1), q); + Apply(q1 => Rx(5.0, q1), q); + Apply(q1 => Rx(6.0, q1), q); + Apply(q1 => Rx(7.0, q1), q); + Apply(q1 => Rx(8.0, q1), q); + Apply(q1 => Rx(9.0, q1), q); + Apply(q1 => Rx(10.0, q1), q); + Apply(q1 => Rx(11.0, q1), q); + } + "#, + ); + let mut assigner = qsc_fir::assigner::Assigner::from_package(fir_store.get(fir_pkg_id)); + let errors = defunctionalize(&mut fir_store, fir_pkg_id, &mut assigner); + + // Should have exactly one warning, no fatal errors. + let warnings: Vec<_> = errors + .iter() + .filter(|e| matches!(e, super::super::Error::ExcessiveSpecializations(..))) + .collect(); + let fatal: Vec<_> = errors + .iter() + .filter(|e| !matches!(e, super::super::Error::ExcessiveSpecializations(..))) + .collect(); + assert_eq!(warnings.len(), 1, "expected exactly one warning"); + assert!(fatal.is_empty(), "expected no fatal errors, got: {fatal:?}"); + + // PostDefunc invariants must still hold. 
+ fir_invariants::check(&fir_store, fir_pkg_id, InvariantLevel::PostDefunc); +} + +#[test] +fn zero_capture_conditional_alias_dispatches_correctly() { + let source = r#" + operation ZeroCaptureConditionalAlias(q : Qubit, useAdj : Bool) : Unit { + let u = if useAdj { Adjoint S } else { S }; + u(q); + } + operation Main() : Unit { + use q = Qubit(); + ZeroCaptureConditionalAlias(q, true); + } + "#; + check( + source, + &expect![[r#" + Main: input_ty=Unit + ZeroCaptureConditionalAlias: input_ty=(Qubit, Bool)"#]], + ); + let targets = callable_call_targets_after_defunc(source, "ZeroCaptureConditionalAlias"); + assert!( + targets.contains(&"Adjoint S".to_string()) && targets.contains(&"S".to_string()), + "conditional alias should preserve both S and Adjoint S dispatch targets, got {targets:?}" + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/defunctionalize/types.rs b/source/compiler/qsc_fir_transforms/src/defunctionalize/types.rs new file mode 100644 index 0000000000..149f97c1d9 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/defunctionalize/types.rs @@ -0,0 +1,408 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Shared types for the defunctionalization pass. +//! +//! These types are used across the analysis, specialization, and rewrite +//! modules to communicate discovered callable parameters, call sites, +//! concrete callable resolutions, and specialization keys. + +#[cfg(test)] +mod tests; + +use miette::Diagnostic; +use rustc_hash::FxHashMap; +use thiserror::Error; + +use qsc_data_structures::functors::FunctorApp; +use qsc_data_structures::span::Span; +use qsc_fir::fir::{ + ExprId, ExprKind, Functor, ItemId, LocalItemId, LocalVarId, Package, PackageLookup, PatId, UnOp, +}; +use qsc_fir::ty::Ty; + +/// A callable parameter detected in a higher-order function declaration. +#[derive(Clone, Debug)] +pub struct CallableParam { + /// The HOF containing this parameter. 
+ pub callable_id: LocalItemId, + /// The pattern node for the parameter. + pub param_pat_id: PatId, + /// The outer input-parameter slot selected before any nested tuple + /// traversal. Single-parameter callables always use `0`. + pub top_level_param: usize, + /// The tuple-field path relative to `top_level_param`. + pub field_path: Vec, + /// The local variable bound by the parameter. + pub param_var: LocalVarId, + /// The Arrow type of the parameter. + pub param_ty: Ty, +} + +impl CallableParam { + #[must_use] + pub fn new( + callable_id: LocalItemId, + param_pat_id: PatId, + top_level_param: usize, + field_path: Vec, + param_var: LocalVarId, + param_ty: Ty, + ) -> Self { + Self { + callable_id, + param_pat_id, + top_level_param, + field_path, + param_var, + param_ty, + } + } +} + +/// A call site where a HOF is called with a concrete callable argument. +#[derive(Clone, Debug)] +pub struct CallSite { + /// The Call expression. + pub call_expr_id: ExprId, + /// The HOF being called. + pub hof_item_id: ItemId, + /// Resolved callable argument. + pub callable_arg: ConcreteCallable, + /// Expression for the callable argument. + pub arg_expr_id: ExprId, + /// Optional condition `ExprId` for branch-split dispatch. When + /// present, this callee is selected when the condition is true. + /// `None` indicates the default (else) branch. + pub condition: Option, +} + +/// A direct call whose callee expression resolves to a concrete callable value. +#[derive(Clone, Debug)] +pub struct DirectCallSite { + /// The Call expression. + pub call_expr_id: ExprId, + /// Resolved concrete callee. + pub callable: ConcreteCallable, + /// Optional condition `ExprId` for branch-split dispatch. When present, + /// this callee is selected when the condition is true. `None` indicates + /// the default (else) branch. + pub condition: Option, +} + +/// A resolved callable value. 
+#[derive(Clone, Debug, PartialEq)] +pub enum ConcreteCallable { + /// A direct global callable reference with accumulated functor application. + Global { + item_id: ItemId, + functor: FunctorApp, + }, + /// A closure with captured variables and accumulated functor application. + Closure { + target: LocalItemId, + captures: Vec, + functor: FunctorApp, + }, + /// Cannot be resolved statically. + Dynamic, +} + +/// A variable captured by a closure. +#[derive(Clone, Debug, PartialEq)] +pub struct CapturedVar { + /// The captured local variable. + pub var: LocalVarId, + /// The type of the captured variable. + pub ty: Ty, + /// An optional initializer expression to reuse when the original local is + /// scoped to a block that rewrite will erase. + pub expr: Option, +} + +/// Maximum number of concrete callables tracked in a `Multi` lattice element +/// before degrading to `Dynamic`. +pub(super) const MULTI_CAP: usize = 8; + +/// Reaching-definitions lattice for callable variables. +/// Tracks the set of possible concrete callables at each program point. +#[derive(Clone, Debug)] +pub enum CalleeLattice { + /// No value assigned yet (before first definition). + Bottom, + /// Exactly one known callable. + Single(ConcreteCallable), + /// Multiple known callables (from conditional branches) — up to + /// [`MULTI_CAP`] before degrading to `Dynamic`. + /// + /// Each entry is `(callable, condition)` where `condition` is the + /// `ExprId` of the if-condition that selects this callee. The last + /// entry typically has `None` (the else branch). + Multi(Vec<(ConcreteCallable, Option)>), + /// Too many or unknown callables — cannot resolve. + Dynamic, +} + +impl CalleeLattice { + /// Constructs a lattice element from a resolved [`ConcreteCallable`]. + #[must_use] + pub fn from_concrete(cc: ConcreteCallable) -> Self { + match cc { + ConcreteCallable::Dynamic => Self::Dynamic, + other => Self::Single(other), + } + } + + /// Joins two lattice elements (least upper bound). 
+ /// + /// - `Bottom ⊔ x = x` + /// - `Single(a) ⊔ Single(a) = Single(a)` (when equal) + /// - `Single(a) ⊔ Single(b) = Multi([a, b])` + /// - `Multi(s) ⊔ Single(a) = Multi(s ∪ {a})` (cap at [`MULTI_CAP`] → Dynamic) + /// - `Multi(s1) ⊔ Multi(s2) = Multi(s1 ∪ s2)` (cap at [`MULTI_CAP`] → Dynamic) + /// - `Dynamic ⊔ _ = Dynamic` + #[must_use] + pub fn join(self, other: Self) -> Self { + match (self, other) { + (Self::Bottom, x) | (x, Self::Bottom) => x, + (Self::Dynamic, _) | (_, Self::Dynamic) => Self::Dynamic, + (Self::Single(a), Self::Single(b)) => { + if a == b { + Self::Single(a) + } else { + Self::Multi(vec![(a, None), (b, None)]) + } + } + (Self::Multi(mut s), Self::Single(a)) | (Self::Single(a), Self::Multi(mut s)) => { + if !s.iter().any(|(cc, _)| *cc == a) { + s.push((a, None)); + } + if s.len() > MULTI_CAP { + Self::Dynamic + } else { + Self::Multi(s) + } + } + (Self::Multi(mut s1), Self::Multi(s2)) => { + for (item, cond) in s2 { + if !s1.iter().any(|(cc, _)| *cc == item) { + s1.push((item, cond)); + } + } + if s1.len() > MULTI_CAP { + Self::Dynamic + } else { + Self::Multi(s1) + } + } + } + } + + /// Joins two lattice elements with an associated condition from an + /// if/else branch. `self` is the state from the **true** branch and + /// `other` from the **false** branch. + /// + /// Condition-tag provenance rules: + /// + /// - When the true branch is a `Single(a)` distinct from the false + /// branch, entry `a` is tagged `Some(condition)` and the false-branch + /// entry keeps its existing tag (or `None` for the else case). + /// - When the false branch contributes a new callable via + /// `Multi(true) ⊔ Single(false)`, that callable is appended with + /// `None` (it is the default/else path). + /// - Entries inherited from an existing `Multi` retain their original + /// tags. 
+ /// - If both branches are `Multi` with identical callable sets the + /// original tags from `s1` are kept unchanged; otherwise the join + /// degrades to `Dynamic` because nested dispatch is not yet + /// supported. + #[must_use] + pub fn join_with_condition(self, other: Self, condition: ExprId) -> Self { + match (self, other) { + (Self::Bottom, x) | (x, Self::Bottom) => x, + (Self::Single(a), Self::Single(b)) => { + if a == b { + Self::Single(a) + } else { + Self::Multi(vec![(a, Some(condition)), (b, None)]) + } + } + (Self::Single(a), Self::Multi(mut s)) => { + // a from true branch (conditioned), s from false branch + if !s.iter().any(|(cc, _)| *cc == a) { + s.insert(0, (a, Some(condition))); + } + if s.len() > MULTI_CAP { + Self::Dynamic + } else { + Self::Multi(s) + } + } + // Multi(true) + Single(false): the true branch already has + // multiple callables. Insert the single false-branch callable + // into the set if it is not already present. + (Self::Multi(mut s), Self::Single(b)) => { + if !s.iter().any(|(cc, _)| *cc == b) { + s.push((b, None)); + } + if s.len() > MULTI_CAP { + Self::Dynamic + } else { + Self::Multi(s) + } + } + // Multi from the true branch requires nested dispatch → too + // complex for the current implementation, UNLESS both sides have + // the same callable set (variable was not modified in the branch). + (Self::Multi(s1), Self::Multi(s2)) => { + let same_callables = s1.len() == s2.len() + && s1 + .iter() + .zip(s2.iter()) + .all(|((cc1, _), (cc2, _))| cc1 == cc2); + if same_callables { + Self::Multi(s1) + } else { + Self::Dynamic + } + } + // Dynamic ⊔ _ = Dynamic. + (Self::Dynamic, _) | (_, Self::Dynamic) => Self::Dynamic, + } + } +} + +/// Deduplication key for specializations. Two call sites that share the same +/// `SpecKey` can reuse the same generated dispatch callable. +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub struct SpecKey { + /// The HOF being specialized. 
+ pub hof_id: LocalItemId, + /// Hashable representations of the concrete callable arguments. + pub concrete_args: Vec, +} + +/// Hashable variant of [`ConcreteCallable`] used for deduplication. Closures +/// are keyed only by their target and functor (captures are structural, not +/// identity-defining). +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub enum ConcreteCallableKey { + /// A direct global callable reference. + Global { + item_id: ItemId, + functor: FunctorApp, + }, + /// A closure keyed by target and functor. + /// + /// Captured variables are intentionally omitted so that two closures + /// with identical targets and functors share a specialization; the + /// captured values are threaded as ordinary arguments at the call site + /// rather than being part of the dispatch identity. + Closure { + target: LocalItemId, + functor: FunctorApp, + }, +} + +/// Per-callable lattice snapshot: maps each callable's `LocalItemId` to the +/// sorted list of `(LocalVarId, CalleeLattice)` entries observed after flow +/// analysis. +pub type LatticeStates = FxHashMap>; + +/// Output of the analysis phase. +#[derive(Clone, Debug, Default)] +pub struct AnalysisResult { + /// Callable parameters with arrow types found in HOF declarations. + pub callable_params: Vec, + /// Call sites where HOFs are invoked with concrete callable arguments. + pub call_sites: Vec, + /// Direct calls whose callee resolves to a concrete callable value. + pub direct_call_sites: Vec, + /// Per-callable lattice states for all callable-typed local variables + /// after flow analysis. + pub lattice_states: LatticeStates, +} + +/// Errors that can occur during defunctionalization. 
+#[derive(Clone, Debug, Diagnostic, Error)] +pub enum Error { + /// Emitted when a callable argument cannot be statically resolved to a + /// concrete set of callables, typically because the number of conditional + /// branches exceeds `MULTI_CAP`, a conditional has mismatched Multi + /// variants, or a mutable callable variable is reassigned in a loop. + #[error("callable argument could not be resolved statically")] + #[diagnostic(code("Qsc.Defunctionalize.DynamicCallable"))] + #[diagnostic(help("ensure all callable arguments are known at compile time"))] + DynamicCallable(#[label] Span), + + /// Reserved; currently unused. Mutable callable parameters are handled + /// via branch-splitting (resolving to `Multi` in the `CalleeLattice`) + /// rather than producing this error. Retained for future use when + /// rejection of mutable callables becomes appropriate. + #[error("callable parameter is mutably assigned")] + #[diagnostic(code("Qsc.Defunctionalize.MutableCallable"))] + MutableCallable(#[label] Span), + + #[error("specialization leads to infinite recursion")] + #[diagnostic(code("Qsc.Defunctionalize.RecursiveSpecialization"))] + RecursiveSpecialization(#[label] Span), + + #[error( + "defunctionalization did not converge within {0} iterations; {1} callable values remain" + )] + #[diagnostic(code("Qsc.Defunctionalize.FixpointNotReached"))] + #[diagnostic(help("consider reducing the nesting depth of higher-order function chains"))] + FixpointNotReached(usize, usize, #[label("remaining callable value")] Span), + + #[error( + "higher-order function `{0}` generated {1} specializations, exceeding the warning threshold" + )] + #[diagnostic(code("Qsc.Defunctionalize.ExcessiveSpecializations"))] + #[diagnostic(severity(warning))] + #[diagnostic(help( + "consider reducing the number of distinct callable arguments passed to this function" + ))] + ExcessiveSpecializations( + String, + usize, + #[label("excessive specializations generated here")] Span, + ), +} + +/// 
Composes two `FunctorApp` values. +/// +/// Adjoint toggles (XOR) and controlled counts stack (saturating addition). +/// This correctly handles double-adjoint cancellation: +/// `compose_functors({adj:true, ..}, {adj:true, ..})` yields `{adj:false, ..}`. +#[must_use] +pub fn compose_functors(creation: &FunctorApp, body: &FunctorApp) -> FunctorApp { + FunctorApp { + adjoint: creation.adjoint ^ body.adjoint, + controlled: creation.controlled.saturating_add(body.controlled), + } +} + +/// Recursively strips `UnOp(Functor(Adj|Ctl), inner)` layers from an +/// expression, accumulating the functor applications into a `FunctorApp`. +/// +/// Returns `(base_expr_id, accumulated_functor_app)` where `base_expr_id` +/// is the innermost expression after all functor wrappers are removed. +#[must_use] +pub fn peel_body_functors(package: &Package, expr_id: ExprId) -> (ExprId, FunctorApp) { + let mut current = expr_id; + let mut functor = FunctorApp::default(); + loop { + let expr = package.get_expr(current); + match &expr.kind { + ExprKind::UnOp(UnOp::Functor(Functor::Adj), inner) => { + functor.adjoint = !functor.adjoint; + current = *inner; + } + ExprKind::UnOp(UnOp::Functor(Functor::Ctl), inner) => { + functor.controlled = functor.controlled.saturating_add(1); + current = *inner; + } + _ => return (current, functor), + } + } +} diff --git a/source/compiler/qsc_fir_transforms/src/defunctionalize/types/tests.rs b/source/compiler/qsc_fir_transforms/src/defunctionalize/types/tests.rs new file mode 100644 index 0000000000..95bbafe182 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/defunctionalize/types/tests.rs @@ -0,0 +1,109 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +use super::*; +use qsc_data_structures::functors::FunctorApp; +use qsc_fir::fir::{ExprId, ItemId, LocalItemId, PackageId}; + +fn global(id: usize) -> ConcreteCallable { + ConcreteCallable::Global { + item_id: ItemId { + package: PackageId::from(0), + item: LocalItemId::from(id), + }, + functor: FunctorApp::default(), + } +} + +fn cond() -> ExprId { + ExprId::from(99u32) +} + +#[test] +fn join_with_condition_single_multi_inserts_into_set() { + let a = global(1); + let b = global(2); + let lhs = CalleeLattice::Single(a.clone()); + let rhs = CalleeLattice::Multi(vec![(b.clone(), Some(ExprId::from(50u32)))]); + + let result = lhs.join_with_condition(rhs, cond()); + + match result { + CalleeLattice::Multi(entries) => { + assert_eq!(entries.len(), 2); + assert_eq!(entries[0], (a, Some(cond()))); + assert_eq!(entries[1], (b, Some(ExprId::from(50u32)))); + } + other => panic!("expected Multi, got {other:?}"), + } +} + +#[test] +fn join_with_condition_multi_single_inserts_into_set() { + let a = global(1); + let b = global(2); + let lhs = CalleeLattice::Multi(vec![(a.clone(), Some(ExprId::from(50u32)))]); + let rhs = CalleeLattice::Single(b.clone()); + + let result = lhs.join_with_condition(rhs, cond()); + + match result { + CalleeLattice::Multi(entries) => { + assert_eq!(entries.len(), 2); + assert_eq!(entries[0], (a, Some(ExprId::from(50u32)))); + assert_eq!(entries[1], (b, None)); + } + other => panic!("expected Multi, got {other:?}"), + } +} + +#[test] +fn join_with_condition_single_same_stays_single() { + let a = global(1); + let result = CalleeLattice::Single(a.clone()) + .join_with_condition(CalleeLattice::Single(a.clone()), cond()); + + match result { + CalleeLattice::Single(cc) => assert_eq!(cc, a), + other => panic!("expected Single, got {other:?}"), + } +} + +#[test] +fn join_with_condition_single_different_produces_multi() { + let a = global(1); + let b = global(2); + let result = CalleeLattice::Single(a.clone()) + 
.join_with_condition(CalleeLattice::Single(b.clone()), cond()); + + match result { + CalleeLattice::Multi(entries) => { + assert_eq!(entries.len(), 2); + assert_eq!(entries[0], (a, Some(cond()))); + assert_eq!(entries[1], (b, None)); + } + other => panic!("expected Multi, got {other:?}"), + } +} + +#[test] +fn join_with_condition_multi_single_cap_exceeded_becomes_dynamic() { + let entries: Vec<(ConcreteCallable, Option)> = (0..MULTI_CAP) + .map(|i| { + ( + global(i), + Some(ExprId::from(u32::try_from(i).expect("id must fit"))), + ) + }) + .collect(); + let extra = global(MULTI_CAP + 10); + let lhs = CalleeLattice::Multi(entries); + let rhs = CalleeLattice::Single(extra); + + let result = lhs.join_with_condition(rhs, cond()); + + assert!( + matches!(result, CalleeLattice::Dynamic), + "expected Dynamic when exceeding MULTI_CAP, got {result:?}" + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/exec_graph_rebuild.rs b/source/compiler/qsc_fir_transforms/src/exec_graph_rebuild.rs new file mode 100644 index 0000000000..02fa37d07f --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/exec_graph_rebuild.rs @@ -0,0 +1,647 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Rebuilds exec graphs for all reachable callables and the entry expression. +//! +//! After earlier FIR transforms synthesize new expressions or statements with +//! empty ranges, the exec graphs on `SpecDecl` and +//! `Package.entry_exec_graph` are stale. In practice this includes return +//! unification, defunctionalization, UDT erasure, tuple-compare lowering, +//! SROA, and argument promotion. This pass reconstructs every graph from +//! scratch by walking the FIR and emitting the same node sequences that the +//! original lowerer would have produced. +//! +//! ## Transformation Shape +//! +//! **Before:** Callable specs and the entry expression carry stale +//! `exec_graph_range` values — often `EMPTY_EXEC_RANGE` sentinels inserted +//! 
by earlier passes. The exec graph vectors may reference deleted or +//! renumbered nodes. +//! +//! **After:** Every reachable callable spec and the entry expression has a +//! freshly built exec graph. Ranges on individual `Expr` and `Stmt` nodes +//! index into the rebuilt vectors. +//! +//! ## Borrow-Splitting Strategy +//! +//! The rebuild cannot hold both `&Package` (for reading expressions) and +//! `&mut Package` (for writing exec graphs) simultaneously. This is solved +//! by accumulating deferred writes in `RangeUpdates`: during the read-only +//! graph-building walk, expression and statement ranges are recorded as +//! `(ExprId, Range)` and `(StmtId, Range)` +//! pairs. After building completes and the immutable borrow ends, +//! `apply_ranges` writes each range back to the corresponding `Expr` or +//! `Stmt` under a mutable borrow. +//! +//! ## `ExecGraphBuilder` Delegation +//! +//! Graph nodes are emitted via `ExecGraphBuilder` from `qsc_lowerer`, which +//! maintains paired no-debug and debug node vectors. This ensures the rebuilt +//! graphs match the format produced by the original lowering pass. +//! +//! ## See Also +//! +//! - `qsc_lowerer::exec_graph` — The `ExecGraphBuilder` that emits graph +//! nodes. The rebuild pass re-uses this builder to ensure graph format +//! fidelity with the original lowering pass. + +#[cfg(test)] +mod tests; + +use std::ops::Range; + +use qsc_fir::fir::{ + BinOp, BlockId, CallableImpl, ExecGraphDebugNode, ExecGraphIdx, ExecGraphNode, ExprId, + ExprKind, ItemKind, LocalItemId, Package, PackageId, PackageLookup, PackageStore, + SpecDecl as FirSpecDecl, StmtId, StmtKind, StoreItemId, StringComponent, +}; +use qsc_fir::ty::Ty; +use qsc_lowerer::ExecGraphBuilder; + +use crate::reachability::{collect_reachable_from_entry, collect_reachable_with_seeds}; + +/// Side-table collecting deferred `exec_graph_range` updates. 
+/// Populated during the read-only graph-building pass, then applied in a +/// separate write pass to avoid simultaneous mutable and immutable borrows. +#[derive(Default)] +struct RangeUpdates { + exprs: Vec<(ExprId, Range)>, + stmts: Vec<(StmtId, Range)>, +} + +/// Applies collected range updates to package expressions and statements. +/// +/// Invoked once per specialization (not once at the end of the pass). Each +/// call writes the ranges gathered for that spec back to the package +/// before the next specialization begins rebuilding. +fn apply_ranges(package: &mut Package, ranges: &RangeUpdates) { + for (id, range) in &ranges.exprs { + package + .exprs + .get_mut(*id) + .expect("expr must exist") + .exec_graph_range = range.clone(); + } + for (id, range) in &ranges.stmts { + package + .stmts + .get_mut(*id) + .expect("stmt must exist") + .exec_graph_range = range.clone(); + } +} + +/// Collected spec info for a single callable — avoids holding a `&Package` +/// reference while mutating. +struct SpecInfo { + block: BlockId, + /// Which specialization on the containing callable should receive the + /// rebuilt graph during write-back. + kind: SpecKind, +} + +/// Which specialization within a `CallableImpl`. +#[derive(Clone, Copy)] +enum SpecKind { + /// The default callable body implementation. + Body, + /// The adjoint specialization. + Adj, + /// The controlled specialization. + Ctl, + /// The controlled-adjoint specialization. + CtlAdj, + /// A simulatable intrinsic with an explicit body block. + SimulatableIntrinsic, +} + +/// All spec infos for one callable item, collected while holding `&Package`. +struct CallableSpecs { + item_id: LocalItemId, + specs: Vec, +} + +/// Rebuilds exec graphs for every reachable callable and the entry expression +/// in the given package. When `pinned_items` is non-empty, uses seed-based +/// reachability to include pinned callables that are not entry-reachable. 
+/// +/// This must be called after all FIR transforms have completed. The function +/// is idempotent — calling it multiple times produces the same result. +pub fn rebuild_exec_graphs( + store: &mut PackageStore, + package_id: PackageId, + pinned_items: &[StoreItemId], +) { + // Early return if there is no entry expression — nothing to rebuild. + { + let package = store.get(package_id); + if package.entry.is_none() { + return; + } + } + + let reachable = if pinned_items.is_empty() { + collect_reachable_from_entry(store, package_id) + } else { + collect_reachable_with_seeds(store, package_id, pinned_items) + }; + + let collected = collect_callable_specs(store, package_id, &reachable); + rebuild_callable_exec_graphs(store, package_id, &collected); + rebuild_entry_exec_graph(store, package_id); +} + +/// Collects the block IDs for every spec in every reachable callable that +/// lives in this package (cross-package items are not rebuilt). +fn collect_callable_specs( + store: &PackageStore, + package_id: PackageId, + reachable: &rustc_hash::FxHashSet, +) -> Vec { + let mut collected: Vec = Vec::new(); + let package = store.get(package_id); + for item_id in reachable { + if item_id.package != package_id { + continue; + } + let item = package.get_item(item_id.item); + let decl = match &item.kind { + ItemKind::Callable(decl) => decl.as_ref(), + _ => continue, + }; + let specs = collect_specs_from_impl(&decl.implementation); + if !specs.is_empty() { + collected.push(CallableSpecs { + item_id: item_id.item, + specs, + }); + } + } + collected +} + +/// Extracts `SpecInfo` entries from a callable implementation. 
+fn collect_specs_from_impl(implementation: &CallableImpl) -> Vec { + let mut specs = Vec::new(); + match implementation { + CallableImpl::Intrinsic => {} + CallableImpl::Spec(spec_impl) => { + specs.push(SpecInfo { + block: spec_impl.body.block, + kind: SpecKind::Body, + }); + if let Some(adj) = &spec_impl.adj { + specs.push(SpecInfo { + block: adj.block, + kind: SpecKind::Adj, + }); + } + if let Some(ctl) = &spec_impl.ctl { + specs.push(SpecInfo { + block: ctl.block, + kind: SpecKind::Ctl, + }); + } + if let Some(ctl_adj) = &spec_impl.ctl_adj { + specs.push(SpecInfo { + block: ctl_adj.block, + kind: SpecKind::CtlAdj, + }); + } + } + CallableImpl::SimulatableIntrinsic(spec) => { + specs.push(SpecInfo { + block: spec.block, + kind: SpecKind::SimulatableIntrinsic, + }); + } + } + specs +} + +/// Rebuilds and writes back the exec graph for each collected callable spec. +fn rebuild_callable_exec_graphs( + store: &mut PackageStore, + package_id: PackageId, + collected: &[CallableSpecs], +) { + for callable in collected { + for spec_info in &callable.specs { + // Build graph — immutable borrow. + let (graph, ranges) = { + let package = store.get(package_id); + let mut builder = ExecGraphBuilder::default(); + let mut ranges = RangeUpdates::default(); + rebuild_block(package, &mut builder, spec_info.block, &mut ranges); + (builder.take(), ranges) + }; + + // Write back — mutable borrow. + let package = store.get_mut(package_id); + apply_ranges(package, &ranges); + + let target_spec = get_spec_decl_mut(package, callable.item_id, spec_info.kind); + target_spec.exec_graph = graph; + } + } +} + +/// Returns a mutable reference to the spec decl identified by `kind` on the +/// callable at `item_id`. 
+fn get_spec_decl_mut( + package: &mut Package, + item_id: LocalItemId, + kind: SpecKind, +) -> &mut FirSpecDecl { + let item = package.items.get_mut(item_id).expect("item must exist"); + let decl = match &mut item.kind { + ItemKind::Callable(decl) => decl.as_mut(), + _ => unreachable!("already verified callable"), + }; + match kind { + SpecKind::Body => match &mut decl.implementation { + CallableImpl::Spec(si) => &mut si.body, + _ => unreachable!("already verified Spec"), + }, + SpecKind::Adj => match &mut decl.implementation { + CallableImpl::Spec(si) => si.adj.as_mut().expect("adj must exist"), + _ => unreachable!("already verified Spec"), + }, + SpecKind::Ctl => match &mut decl.implementation { + CallableImpl::Spec(si) => si.ctl.as_mut().expect("ctl must exist"), + _ => unreachable!("already verified Spec"), + }, + SpecKind::CtlAdj => match &mut decl.implementation { + CallableImpl::Spec(si) => si.ctl_adj.as_mut().expect("ctl_adj must exist"), + _ => unreachable!("already verified Spec"), + }, + SpecKind::SimulatableIntrinsic => match &mut decl.implementation { + CallableImpl::SimulatableIntrinsic(spec) => spec, + _ => unreachable!("already verified SimulatableIntrinsic"), + }, + } +} + +/// Rebuilds the entry exec graph from the package's entry expression. 
+fn rebuild_entry_exec_graph(store: &mut PackageStore, package_id: PackageId) { + let entry_id = store + .get(package_id) + .entry + .expect("entry must exist; caller guards against missing entry"); + let (graph, ranges) = { + let package = store.get(package_id); + let mut builder = ExecGraphBuilder::default(); + let mut ranges = RangeUpdates::default(); + rebuild_expr(package, &mut builder, entry_id, &mut ranges); + (builder.take(), ranges) + }; + let package = store.get_mut(package_id); + package.entry_exec_graph = graph; + apply_ranges(package, &ranges); +} + +/// Rebuilds the execution graph for a block by visiting each statement and +/// appending a `Unit` node when the block is empty or does not end with +/// an expression statement. +fn rebuild_block( + package: &Package, + builder: &mut ExecGraphBuilder, + block_id: BlockId, + ranges: &mut RangeUpdates, +) { + builder.debug_push(ExecGraphDebugNode::PushScope); + + let block = package.get_block(block_id); + let stmts = block.stmts.clone(); + + let set_unit = stmts.is_empty() + || !matches!( + package.get_stmt(*stmts.last().expect("non-empty")).kind, + StmtKind::Expr(..) + ); + + for &stmt_id in &stmts { + rebuild_stmt(package, builder, stmt_id, ranges); + } + + if set_unit { + builder.push(ExecGraphNode::Unit); + } + + builder.debug_push(ExecGraphDebugNode::BlockEnd(block_id)); + builder.debug_push(ExecGraphDebugNode::PopScope); +} + +/// Rebuilds the execution graph for a single statement. `Local` bindings +/// emit a `Bind` node after the initializer expression; `Item` statements +/// are no-ops. 
+fn rebuild_stmt( + package: &Package, + builder: &mut ExecGraphBuilder, + stmt_id: StmtId, + ranges: &mut RangeUpdates, +) { + let graph_start = builder.len(); + builder.debug_push(ExecGraphDebugNode::Stmt(stmt_id)); + + let kind = package.get_stmt(stmt_id).kind.clone(); + match kind { + StmtKind::Expr(expr_id) | StmtKind::Semi(expr_id) => { + rebuild_expr(package, builder, expr_id, ranges); + } + StmtKind::Local(_, pat_id, expr_id) => { + rebuild_expr(package, builder, expr_id, ranges); + builder.push(ExecGraphNode::Bind(pat_id)); + } + StmtKind::Item(_) => {} + } + + ranges.stmts.push((stmt_id, graph_start..builder.len())); +} + +/// Rebuilds the execution graph for an expression, recursively visiting +/// sub-expressions. Control-flow expressions (`If`, `While`, short-circuit +/// operators) produce jump nodes; assignments use `truncate` to discard +/// the LHS target nodes; multi-operand expressions interleave `Store` +/// nodes to preserve intermediate values on the evaluation stack. 
#[allow(clippy::too_many_lines)]
fn rebuild_expr(
    package: &Package,
    builder: &mut ExecGraphBuilder,
    expr_id: ExprId,
    ranges: &mut RangeUpdates,
) {
    let graph_start = builder.len();
    let expr = package.get_expr(expr_id);
    let kind = expr.kind.clone();

    match kind {
        // Control flow (no trailing Expr(id))
        ExprKind::BinOp(BinOp::AndL, lhs, rhs) => {
            // Short-circuit AND: skip the RHS when the LHS is false. The
            // jump target is patched once the RHS length is known.
            rebuild_expr(package, builder, lhs, ranges);
            let idx = builder.len();
            builder.push(ExecGraphNode::Jump(0));
            rebuild_expr(package, builder, rhs, ranges);
            builder.set_with_arg(ExecGraphNode::JumpIfNot, idx, builder.len());
        }

        ExprKind::BinOp(BinOp::OrL, lhs, rhs) => {
            // Short-circuit OR: skip the RHS when the LHS is true.
            rebuild_expr(package, builder, lhs, ranges);
            let idx = builder.len();
            builder.push(ExecGraphNode::Jump(0));
            rebuild_expr(package, builder, rhs, ranges);
            builder.set_with_arg(ExecGraphNode::JumpIf, idx, builder.len());
        }

        ExprKind::Block(block_id) => {
            rebuild_block(package, builder, block_id, ranges);
        }

        ExprKind::If(cond, if_true, if_false) => {
            rebuild_expr(package, builder, cond, ranges);
            let branch_idx = builder.len();
            builder.push(ExecGraphNode::Jump(0));
            rebuild_expr(package, builder, if_true, ranges);

            if let Some(else_id) = if_false {
                // With else branch: the then-branch ends with a jump over
                // the else nodes; the conditional jump targets the first
                // else node (idx + 1, just past the patched Jump).
                let idx = builder.len();
                builder.push(ExecGraphNode::Jump(0));
                rebuild_expr(package, builder, else_id, ranges);
                builder.set_with_arg(ExecGraphNode::Jump, idx, builder.len());
                let else_idx = idx + 1;
                builder.set_with_arg(ExecGraphNode::JumpIfNot, branch_idx, else_idx);
            } else {
                // Without else — produces Unit.
                let idx = builder.len();
                builder.push(ExecGraphNode::Unit);
                builder.set_with_arg(ExecGraphNode::JumpIfNot, branch_idx, idx);
            }
        }

        ExprKind::While(cond, body_block) => {
            // Loop shape: cond, conditional exit jump, body, back-edge to
            // cond, then a trailing Unit as the loop's value.
            builder.debug_push(ExecGraphDebugNode::PushLoopScope(expr_id));
            let cond_idx = builder.len();
            rebuild_expr(package, builder, cond, ranges);
            let idx = builder.len();
            builder.push(ExecGraphNode::Jump(0));
            builder.debug_push(ExecGraphDebugNode::LoopIteration);
            rebuild_block(package, builder, body_block, ranges);
            builder.push_with_arg(ExecGraphNode::Jump, cond_idx);
            builder.set_with_arg(ExecGraphNode::JumpIfNot, idx, builder.len());
            builder.debug_push(ExecGraphDebugNode::PopScope);
            builder.push(ExecGraphNode::Unit);
        }

        ExprKind::Return(inner) => {
            rebuild_expr(package, builder, inner, ranges);
            builder.push_ret();
        }

        // Assignments (trailing Expr(id) + Unit)
        ExprKind::Assign(lhs, rhs) => {
            // Visit the LHS to record its range, then truncate the emitted
            // nodes — the LHS is an assignment target, not a value to evaluate.
            let idx = builder.len();
            rebuild_expr(package, builder, lhs, ranges);
            builder.truncate(idx);
            rebuild_expr(package, builder, rhs, ranges);
            builder.push(ExecGraphNode::Expr(expr_id));
            builder.push(ExecGraphNode::Unit);
        }

        ExprKind::AssignOp(op, lhs, rhs) => {
            let idx = builder.len();
            let is_array = matches!(package.get_expr(lhs).ty, Ty::Array(..));
            rebuild_expr(package, builder, lhs, ranges);

            if is_array {
                // Array assignment targets are not evaluated — truncate the
                // LHS nodes so only the RHS value remains on the stack.
                builder.truncate(idx);
            }

            // `idx` is intentionally rebound: it now marks the slot used
            // for the short-circuit jump (AndL/OrL) patched below.
            let idx = builder.len();
            if matches!(op, BinOp::AndL | BinOp::OrL) {
                builder.push(ExecGraphNode::Jump(0));
            } else if !is_array {
                builder.push(ExecGraphNode::Store);
            }

            rebuild_expr(package, builder, rhs, ranges);

            match op {
                BinOp::AndL => {
                    builder.set_with_arg(ExecGraphNode::JumpIfNot, idx, builder.len());
                }
                BinOp::OrL => {
                    builder.set_with_arg(ExecGraphNode::JumpIf, idx, builder.len());
                }
                _ => {}
            }

            builder.push(ExecGraphNode::Expr(expr_id));
            builder.push(ExecGraphNode::Unit);
        }

        ExprKind::AssignField(container, _field, replace) => {
            rebuild_expr(package, builder, replace, ranges);
            builder.push(ExecGraphNode::Store);
            rebuild_expr(package, builder, container, ranges);
            builder.push(ExecGraphNode::Expr(expr_id));
            builder.push(ExecGraphNode::Unit);
        }

        ExprKind::AssignIndex(container, index, replace) => {
            rebuild_expr(package, builder, index, ranges);
            builder.push(ExecGraphNode::Store);
            rebuild_expr(package, builder, replace, ranges);
            // Truncate: container is the assignment target, not a value.
            let idx = builder.len();
            rebuild_expr(package, builder, container, ranges);
            builder.truncate(idx);
            builder.push(ExecGraphNode::Expr(expr_id));
            builder.push(ExecGraphNode::Unit);
        }

        // Multi-operand with Store (trailing Expr(id))
        // Each sub-expression is followed by a Store node that pushes its
        // value onto the evaluation stack, keeping all operands available
        // when the final Expr node evaluates the compound expression.
        //
        // Note: `ExprKind::Array` emits a `Store` after each item (items
        // are kept on the value stack for the final `Expr` node), while
        // `ExprKind::ArrayLit` pops after each item. This asymmetry
        // matches the evaluator's expected stack shape for the two
        // array-construction variants.
        ExprKind::Array(items) | ExprKind::Tuple(items) => {
            for item_id in &items {
                rebuild_expr(package, builder, *item_id, ranges);
                builder.push(ExecGraphNode::Store);
            }
            builder.push(ExecGraphNode::Expr(expr_id));
        }

        ExprKind::ArrayLit(items) => {
            for item_id in &items {
                rebuild_expr(package, builder, *item_id, ranges);
                builder.pop();
            }
            builder.push(ExecGraphNode::Expr(expr_id));
        }

        ExprKind::ArrayRepeat(val, size) => {
            rebuild_expr(package, builder, val, ranges);
            builder.push(ExecGraphNode::Store);
            rebuild_expr(package, builder, size, ranges);
            builder.push(ExecGraphNode::Expr(expr_id));
        }

        ExprKind::BinOp(_op, lhs, rhs) => {
            // Non-short-circuit binary op (AndL/OrL handled above).
            // Store saves the LHS value so both operands are available
            // when the Expr node evaluates the operation.
            rebuild_expr(package, builder, lhs, ranges);
            builder.push(ExecGraphNode::Store);
            rebuild_expr(package, builder, rhs, ranges);
            builder.push(ExecGraphNode::Expr(expr_id));
        }

        ExprKind::Call(callee, arg) => {
            // Evaluate and store the callee, then evaluate the argument.
            // The Expr node performs the actual call dispatch at runtime.
            rebuild_expr(package, builder, callee, ranges);
            builder.push(ExecGraphNode::Store);
            rebuild_expr(package, builder, arg, ranges);
            builder.push(ExecGraphNode::Expr(expr_id));
        }

        ExprKind::Index(container, index) => {
            rebuild_expr(package, builder, container, ranges);
            builder.push(ExecGraphNode::Store);
            rebuild_expr(package, builder, index, ranges);
            builder.push(ExecGraphNode::Expr(expr_id));
        }

        ExprKind::UpdateField(record, _field, replace) => {
            rebuild_expr(package, builder, replace, ranges);
            builder.push(ExecGraphNode::Store);
            rebuild_expr(package, builder, record, ranges);
            builder.push(ExecGraphNode::Expr(expr_id));
        }

        ExprKind::UpdateIndex(lhs, mid, rhs) => {
            rebuild_expr(package, builder, mid, ranges);
            builder.push(ExecGraphNode::Store);
            rebuild_expr(package, builder, rhs, ranges);
            builder.push(ExecGraphNode::Store);
            rebuild_expr(package, builder, lhs, ranges);
            builder.push(ExecGraphNode::Expr(expr_id));
        }

        ExprKind::Range(start, step, end) => {
            // Present components are evaluated in start/step/end order; only
            // the components preceding another one need an interposed Store.
            if let Some(s) = start {
                rebuild_expr(package, builder, s, ranges);
                builder.push(ExecGraphNode::Store);
            }
            if let Some(st) = step {
                rebuild_expr(package, builder, st, ranges);
                builder.push(ExecGraphNode::Store);
            }
            if let Some(e) = end {
                rebuild_expr(package, builder, e, ranges);
            }
            builder.push(ExecGraphNode::Expr(expr_id));
        }

        ExprKind::String(components) => {
            // Only interpolated expression components emit nodes; literal
            // components are handled by the final Expr node.
            for component in &components {
                if let StringComponent::Expr(comp_expr_id) = component {
                    rebuild_expr(package, builder, *comp_expr_id, ranges);
                    builder.push(ExecGraphNode::Store);
                }
            }
            builder.push(ExecGraphNode::Expr(expr_id));
        }

        // Simple variants (just Expr(id))
        ExprKind::Lit(..) | ExprKind::Var(..) => {
            builder.push(ExecGraphNode::Expr(expr_id));
        }

        ExprKind::Fail(msg) => {
            rebuild_expr(package, builder, msg, ranges);
            builder.push(ExecGraphNode::Expr(expr_id));
        }

        ExprKind::Field(container, _) => {
            rebuild_expr(package, builder, container, ranges);
            builder.push(ExecGraphNode::Expr(expr_id));
        }

        ExprKind::UnOp(_, operand) => {
            rebuild_expr(package, builder, operand, ranges);
            builder.push(ExecGraphNode::Expr(expr_id));
        }

        // Eliminated variant
        //
        // `ExprKind::Struct` must be unreachable here: the UDT erasure pass
        // establishes [`crate::invariants::InvariantLevel::PostUdtErase`],
        // which guarantees that no `ExprKind::Struct` survives into
        // exec-graph rebuild.
        ExprKind::Struct(..) => {
            panic!("Struct expressions should have been eliminated by udt_erase");
        }

        // Eliminated variant
        //
        // Closures and holes are forbidden by the `PostDefunc` invariant,
        // so they are unreachable at this pipeline stage.
        ExprKind::Closure(..) | ExprKind::Hole => {
            panic!("Closure and hole expressions should have been eliminated by post_defunc");
        }
    }

    ranges.exprs.push((expr_id, graph_start..builder.len()));
}

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

// Proptest applicability: N/A — exec_graph_rebuild is a structural reconstruction pass whose
// correctness is that rebuilt graphs match the format the original lowerer would produce.
// There is no semantic equivalence observable at the Q# level. Testing requires comparing
// graph node sequences, which is better served by targeted snapshot tests.
+ +use crate::test_utils::{ + PipelineStage, compile_and_run_pipeline_to, expr_kind_short, stmt_kind_short, +}; +use expect_test::{Expect, expect}; +use indoc::indoc; +use qsc_fir::fir::{ + CallableDecl, CallableImpl, ExecGraphConfig, ExecGraphDebugNode, ExecGraphNode, ExprId, Field, + ItemKind, LocalVarId, PackageLookup, PatId, PatKind, Res, StoreItemId, +}; +use rustc_hash::FxHashMap; + +#[derive(Clone, Copy)] +enum CallableSpecKind { + Body, + Adj, + Ctl, + CtlAdj, + SimulatableIntrinsic, +} + +/// Formats the body spec exec graph of the entry callable as a string for +/// snapshot testing. Each node is printed on its own line with its index. +fn format_callable_exec_graph( + store: &qsc_fir::fir::PackageStore, + pkg_id: qsc_fir::fir::PackageId, + config: ExecGraphConfig, +) -> String { + let package = store.get(pkg_id); + let reachable = crate::reachability::collect_reachable_from_entry(store, pkg_id); + + // Find the entry callable (the one in our package). + for store_id in &reachable { + if store_id.package != pkg_id { + continue; + } + let item = package.get_item(store_id.item); + if let ItemKind::Callable(decl) = &item.kind + && decl.name.name.as_ref() == "Main" + && let CallableImpl::Spec(spec) = &decl.implementation + { + let graph = spec.body.exec_graph.clone().select(config); + return graph + .iter() + .enumerate() + .map(|(i, node)| match node { + ExecGraphNode::Expr(expr_id) => { + let label = expr_kind_short(package, *expr_id); + format!("{i}: Expr({expr_id:?}) [{label}]") + } + ExecGraphNode::Debug(ExecGraphDebugNode::Stmt(stmt_id)) => { + let label = stmt_kind_short(package, *stmt_id); + format!("{i}: Debug(Stmt({stmt_id:?})) [{label}]") + } + _ => format!("{i}: {node:?}"), + }) + .collect::>() + .join("\n"); + } + } + panic!("Main callable not found"); +} + +fn find_callable<'a>(package: &'a qsc_fir::fir::Package, callable_name: &str) -> &'a CallableDecl { + package + .items + .values() + .find_map(|item| match &item.kind { + 
ItemKind::Callable(decl) if decl.name.name.as_ref() == callable_name => { + Some(decl.as_ref()) + } + _ => None, + }) + .unwrap_or_else(|| panic!("callable '{callable_name}' not found")) +} + +fn collect_pat_names( + package: &qsc_fir::fir::Package, + pat_id: PatId, + names: &mut FxHashMap, +) { + let pat = package.get_pat(pat_id); + match &pat.kind { + PatKind::Bind(ident) => { + names.insert(ident.id, ident.name.to_string()); + } + PatKind::Tuple(sub_pats) => { + for &sub_pat_id in sub_pats { + collect_pat_names(package, sub_pat_id, names); + } + } + PatKind::Discard => {} + } +} + +fn callable_local_names( + package: &qsc_fir::fir::Package, + callable: &CallableDecl, +) -> FxHashMap { + let mut names = FxHashMap::default(); + collect_pat_names(package, callable.input, &mut names); + + match &callable.implementation { + CallableImpl::Spec(spec_impl) => { + for spec in std::iter::once(&spec_impl.body) + .chain(spec_impl.adj.iter()) + .chain(spec_impl.ctl.iter()) + .chain(spec_impl.ctl_adj.iter()) + { + if let Some(input_pat) = spec.input { + collect_pat_names(package, input_pat, &mut names); + } + } + } + CallableImpl::SimulatableIntrinsic(spec) => { + if let Some(input_pat) = spec.input { + collect_pat_names(package, input_pat, &mut names); + } + } + CallableImpl::Intrinsic => {} + } + + names +} + +fn bind_label(package: &qsc_fir::fir::Package, pat_id: PatId) -> String { + let pat = package.get_pat(pat_id); + match &pat.kind { + PatKind::Bind(ident) => format!("Bind({})", ident.name), + PatKind::Tuple(_) => "Bind(tuple)".to_string(), + PatKind::Discard => "Bind(_)".to_string(), + } +} + +fn item_name(store: &qsc_fir::fir::PackageStore, item_id: &qsc_fir::fir::ItemId) -> String { + let package = store.get(item_id.package); + match &package.get_item(item_id.item).kind { + ItemKind::Callable(decl) => decl.name.name.to_string(), + _ => format!("{item_id:?}"), + } +} + +fn semantic_expr_label( + store: &qsc_fir::fir::PackageStore, + package: &qsc_fir::fir::Package, + 
local_names: &FxHashMap, + expr_id: ExprId, +) -> String { + let expr = package.get_expr(expr_id); + match &expr.kind { + qsc_fir::fir::ExprKind::Field(record_id, Field::Path(path)) => { + let mut formatted = semantic_expr_label(store, package, local_names, *record_id); + for index in &path.indices { + formatted.push('.'); + formatted.push_str(&index.to_string()); + } + formatted + } + qsc_fir::fir::ExprKind::Lit(lit) => format!("Lit({lit:?})"), + qsc_fir::fir::ExprKind::Tuple(items) => format!("Tuple(len={})", items.len()), + qsc_fir::fir::ExprKind::UnOp(op, operand_id) => format!( + "{op:?}({})", + semantic_expr_label(store, package, local_names, *operand_id) + ), + qsc_fir::fir::ExprKind::Var(Res::Item(item_id), _) => item_name(store, item_id), + qsc_fir::fir::ExprKind::Var(Res::Local(local_id), _) => { + local_names.get(local_id).map_or_else( + || format!("Var({local_id:?})"), + |name| format!("Var({name})"), + ) + } + _ => expr_kind_short(package, expr_id), + } +} + +fn format_callable_spec_exec_graph( + store: &qsc_fir::fir::PackageStore, + pkg_id: qsc_fir::fir::PackageId, + callable_name: &str, + spec_kind: CallableSpecKind, +) -> String { + let package = store.get(pkg_id); + let callable = find_callable(package, callable_name); + let local_names = callable_local_names(package, callable); + let spec = match (spec_kind, &callable.implementation) { + (CallableSpecKind::Body, CallableImpl::Spec(spec_impl)) => &spec_impl.body, + (CallableSpecKind::Adj, CallableImpl::Spec(spec_impl)) => { + spec_impl.adj.as_ref().expect("adjoint spec should exist") + } + (CallableSpecKind::Ctl, CallableImpl::Spec(spec_impl)) => spec_impl + .ctl + .as_ref() + .expect("controlled spec should exist"), + (CallableSpecKind::CtlAdj, CallableImpl::Spec(spec_impl)) => spec_impl + .ctl_adj + .as_ref() + .expect("controlled adjoint spec should exist"), + (CallableSpecKind::SimulatableIntrinsic, CallableImpl::SimulatableIntrinsic(spec)) => spec, + _ => panic!("requested spec kind is not 
present on '{callable_name}'"), + }; + + format_exec_graph_nodes( + store, + package, + &local_names, + spec.exec_graph.select_ref(ExecGraphConfig::NoDebug), + ) +} + +fn format_exec_graph_nodes( + store: &qsc_fir::fir::PackageStore, + package: &qsc_fir::fir::Package, + local_names: &FxHashMap, + graph: &[ExecGraphNode], +) -> String { + graph + .iter() + .enumerate() + .map(|(index, node)| match node { + ExecGraphNode::Bind(pat_id) => format!("{index}: {}", bind_label(package, *pat_id)), + ExecGraphNode::Expr(expr_id) => format!( + "{index}: {}", + semantic_expr_label(store, package, local_names, *expr_id) + ), + ExecGraphNode::Jump(target) => format!("{index}: Jump({target})"), + ExecGraphNode::JumpIf(target) => format!("{index}: JumpIf({target})"), + ExecGraphNode::JumpIfNot(target) => format!("{index}: JumpIfNot({target})"), + ExecGraphNode::Ret => format!("{index}: Ret"), + ExecGraphNode::Store => format!("{index}: Store"), + ExecGraphNode::Unit => format!("{index}: Unit"), + ExecGraphNode::Debug(_) => { + unreachable!("NoDebug exec graph should not contain debug nodes") + } + }) + .collect::>() + .join("\n") +} + +fn format_store_callable_exec_graph( + store: &qsc_fir::fir::PackageStore, + store_item_id: StoreItemId, + config: ExecGraphConfig, +) -> String { + let package = store.get(store_item_id.package); + let item = package.get_item(store_item_id.item); + let ItemKind::Callable(decl) = &item.kind else { + panic!("reachable item should be callable"); + }; + let local_names = callable_local_names(package, decl); + let spec = match &decl.implementation { + CallableImpl::Spec(spec_impl) => &spec_impl.body, + CallableImpl::SimulatableIntrinsic(spec) => spec, + CallableImpl::Intrinsic => panic!("callable '{}' should have a body", decl.name.name), + }; + + format_exec_graph_nodes( + store, + package, + &local_names, + spec.exec_graph.select_ref(config), + ) +} + +fn clear_store_callable_exec_graph( + store: &mut qsc_fir::fir::PackageStore, + store_item_id: 
StoreItemId, +) { + let package = store.get_mut(store_item_id.package); + let item = package + .items + .get_mut(store_item_id.item) + .expect("reachable item should exist"); + let ItemKind::Callable(decl) = &mut item.kind else { + panic!("reachable item should be callable"); + }; + + match &mut decl.implementation { + CallableImpl::Spec(spec_impl) => spec_impl.body.exec_graph = Default::default(), + CallableImpl::SimulatableIntrinsic(spec) => spec.exec_graph = Default::default(), + CallableImpl::Intrinsic => panic!("callable '{}' should have a body", decl.name.name), + } +} + +fn callable_body_exec_graph_len( + store: &qsc_fir::fir::PackageStore, + store_item_id: StoreItemId, +) -> usize { + let package = store.get(store_item_id.package); + let item = package.get_item(store_item_id.item); + let ItemKind::Callable(decl) = &item.kind else { + panic!("reachable item should be callable"); + }; + + match &decl.implementation { + CallableImpl::Spec(spec_impl) => spec_impl + .body + .exec_graph + .select_ref(ExecGraphConfig::NoDebug) + .len(), + CallableImpl::SimulatableIntrinsic(spec) => { + spec.exec_graph.select_ref(ExecGraphConfig::NoDebug).len() + } + CallableImpl::Intrinsic => panic!("callable '{}' should have a body", decl.name.name), + } +} + +fn reachable_callable_names_with_packages( + store: &qsc_fir::fir::PackageStore, + pkg_id: qsc_fir::fir::PackageId, +) -> Vec { + let mut names = crate::reachability::collect_reachable_from_entry(store, pkg_id) + .into_iter() + .filter_map(|store_item_id| { + let package = store.get(store_item_id.package); + let item = package.get_item(store_item_id.item); + match &item.kind { + ItemKind::Callable(decl) => Some(format!( + "pkg={:?} {}", + store_item_id.package, decl.name.name + )), + _ => None, + } + }) + .collect::>(); + names.sort(); + names +} + +fn find_reachable_callable_by_name( + store: &qsc_fir::fir::PackageStore, + root_pkg_id: qsc_fir::fir::PackageId, + callable_name: &str, + same_package_as_root: bool, +) -> 
StoreItemId { + crate::reachability::collect_reachable_from_entry(store, root_pkg_id) + .into_iter() + .find(|store_item_id| { + let package = store.get(store_item_id.package); + let item = package.get_item(store_item_id.item); + matches!( + &item.kind, + ItemKind::Callable(decl) + if decl.name.name.as_ref() == callable_name + && (store_item_id.package == root_pkg_id) == same_package_as_root + ) + }) + .unwrap_or_else(|| { + panic!( + "reachable callable '{callable_name}' not found\n{}", + reachable_callable_names_with_packages(store, root_pkg_id).join("\n") + ) + }) +} + +/// Compiles Q# source through the pipeline (including exec graph rebuild) +/// and asserts the Main callable's body exec graph (`NoDebug` config) matches. +fn check_exec_graph(source: &str, expect: &Expect) { + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ExecGraphRebuild); + let result = format_callable_exec_graph(&store, pkg_id, ExecGraphConfig::NoDebug); + expect.assert_eq(&result); +} + +fn check_callable_spec_exec_graph( + source: &str, + callable_name: &str, + spec_kind: CallableSpecKind, + expect: &Expect, +) { + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ExecGraphRebuild); + let result = format_callable_spec_exec_graph(&store, pkg_id, callable_name, spec_kind); + expect.assert_eq(&result); +} + +#[test] +fn literal_int_emits_single_expr_node() { + check_exec_graph( + "function Main() : Int { 42 }", + &expect![[r#" + 0: Expr(ExprId(3)) [Lit(Int(42))] + 1: Ret"#]], + ); +} + +#[test] +fn binop_add_evaluates_operands_then_expr() { + check_exec_graph( + "function Main() : Int { 1 + 2 }", + &expect![[r#" + 0: Expr(ExprId(4)) [Lit(Int(1))] + 1: Store + 2: Expr(ExprId(5)) [Lit(Int(2))] + 3: Expr(ExprId(3)) [BinOp(Add)] + 4: Ret"#]], + ); +} + +#[test] +fn tuple_construction_emits_store_per_element() { + check_exec_graph( + "function Main() : (Int, Int) { (1, 2) }", + &expect![[r#" + 0: Expr(ExprId(4)) [Lit(Int(1))] + 1: Store + 2: 
Expr(ExprId(5)) [Lit(Int(2))] + 3: Store + 4: Expr(ExprId(3)) [Tuple(len=2)] + 5: Ret"#]], + ); +} + +#[test] +fn if_else_emits_jump_if_not_with_both_branches() { + check_exec_graph( + "function Main() : Int { if true { 1 } else { 2 } }", + &expect![[r#" + 0: Expr(ExprId(4)) [Lit(Bool(true))] + 1: JumpIfNot(4) + 2: Expr(ExprId(6)) [Lit(Int(1))] + 3: Jump(5) + 4: Expr(ExprId(8)) [Lit(Int(2))] + 5: Ret"#]], + ); +} + +#[test] +fn while_loop_emits_jump_back_to_condition() { + check_exec_graph( + "function Main() : Unit { + mutable i = 0; + while i < 3 { + i += 1; + } + }", + &expect![[r#" + 0: Expr(ExprId(3)) [Lit(Int(0))] + 1: Bind(PatId(1)) + 2: Expr(ExprId(6)) [Var] + 3: Store + 4: Expr(ExprId(7)) [Lit(Int(3))] + 5: Expr(ExprId(5)) [BinOp(Lt)] + 6: JumpIfNot(14) + 7: Expr(ExprId(9)) [Var] + 8: Store + 9: Expr(ExprId(10)) [Lit(Int(1))] + 10: Expr(ExprId(8)) [AssignOp(Add)] + 11: Unit + 12: Unit + 13: Jump(2) + 14: Unit + 15: Ret"#]], + ); +} + +#[test] +fn andl_emits_jump_if_not_for_short_circuit() { + check_exec_graph( + "function Main() : Bool { true and false }", + &expect![[r#" + 0: Expr(ExprId(4)) [Lit(Bool(true))] + 1: JumpIfNot(3) + 2: Expr(ExprId(5)) [Lit(Bool(false))] + 3: Ret"#]], + ); +} + +#[test] +fn let_binding_stores_value_then_evaluates_body() { + check_exec_graph( + "function Main() : Int { let x = 42; x }", + &expect![[r#" + 0: Expr(ExprId(3)) [Lit(Int(42))] + 1: Bind(PatId(1)) + 2: Expr(ExprId(4)) [Var] + 3: Ret"#]], + ); +} + +#[test] +fn tuple_eq_lowered_to_element_wise_andl_chain() { + // KEY TEST: classical tuple eq is now decomposed and the exec graph + // must contain the short-circuit AndL pattern instead of a single BinOp. 
+ check_exec_graph( + "function Main() : Bool { (1, 2) == (1, 2) }", + &expect![[r#" + 0: Expr(ExprId(5)) [Lit(Int(1))] + 1: Store + 2: Expr(ExprId(8)) [Lit(Int(1))] + 3: Expr(ExprId(10)) [BinOp(Eq)] + 4: JumpIfNot(9) + 5: Expr(ExprId(6)) [Lit(Int(2))] + 6: Store + 7: Expr(ExprId(9)) [Lit(Int(2))] + 8: Expr(ExprId(11)) [BinOp(Eq)] + 9: Ret"#]], + ); +} + +#[test] +fn nested_blocks_flatten_to_sequential_nodes() { + check_exec_graph( + "function Main() : Int { let x = { let y = 1; y + 1 }; x }", + &expect![[r#" + 0: Expr(ExprId(4)) [Lit(Int(1))] + 1: Bind(PatId(2)) + 2: Expr(ExprId(6)) [Var] + 3: Store + 4: Expr(ExprId(7)) [Lit(Int(1))] + 5: Expr(ExprId(5)) [BinOp(Add)] + 6: Bind(PatId(1)) + 7: Expr(ExprId(8)) [Var] + 8: Ret"#]], + ); +} + +#[test] +fn orl_short_circuit_emits_jump_if() { + check_exec_graph( + "function Main() : Bool { true or false }", + &expect![[r#" + 0: Expr(ExprId(4)) [Lit(Bool(true))] + 1: JumpIf(3) + 2: Expr(ExprId(5)) [Lit(Bool(false))] + 3: Ret"#]], + ); +} + +#[test] +fn return_expression_emits_ret_node() { + // After return unification, `return 42;` is simplified to a trailing `42`, + // so the exec graph only contains the expression and the final Ret. 
+ check_exec_graph( + "function Main() : Int { return 42; }", + &expect![[r#" + 0: Expr(ExprId(4)) [Lit(Int(42))] + 1: Ret"#]], + ); +} + +#[test] +fn fail_expression_evaluates_message_then_expr() { + check_exec_graph( + "function Main() : Unit { fail \"error\"; }", + &expect![[r#" + 0: Expr(ExprId(4)) [String(parts=1)] + 1: Expr(ExprId(3)) [Fail] + 2: Unit + 3: Ret"#]], + ); +} + +#[test] +fn assign_index_emits_store_and_expr_unit() { + check_exec_graph( + "function Main() : Int[] { mutable arr = [1, 2, 3]; set arr w/= 0 <- 42; arr }", + &expect![[r#" + 0: Expr(ExprId(3)) [ArrayLit(len=3)] + 1: Bind(PatId(1)) + 2: Expr(ExprId(8)) [Lit(Int(0))] + 3: Store + 4: Expr(ExprId(9)) [Lit(Int(42))] + 5: Expr(ExprId(7)) [AssignIndex] + 6: Unit + 7: Expr(ExprId(11)) [Var] + 8: Ret"#]], + ); +} + +#[test] +fn exec_graph_array_repeat_emits_store_pattern() { + check_exec_graph( + "function Main() : Int[] { let arr = [0, size = 3]; arr }", + &expect![[r#" + 0: Expr(ExprId(4)) [Lit(Int(0))] + 1: Store + 2: Expr(ExprId(5)) [Lit(Int(3))] + 3: Expr(ExprId(3)) [ArrayRepeat] + 4: Bind(PatId(1)) + 5: Expr(ExprId(6)) [Var] + 6: Ret"#]], + ); +} + +#[test] +fn exec_graph_range_expression() { + check_exec_graph( + "function Main() : Range { 0..10 }", + &expect![[r#" + 0: Expr(ExprId(4)) [Lit(Int(0))] + 1: Store + 2: Expr(ExprId(5)) [Lit(Int(10))] + 3: Expr(ExprId(3)) [Range] + 4: Ret"#]], + ); +} + +#[test] +fn exec_graph_string_interpolation() { + check_exec_graph( + r#"function Main() : String { let x = 42; $"value = {x}" }"#, + &expect![[r#" + 0: Expr(ExprId(3)) [Lit(Int(42))] + 1: Bind(PatId(1)) + 2: Expr(ExprId(5)) [Var] + 3: Store + 4: Expr(ExprId(4)) [String(parts=2)] + 5: Ret"#]], + ); +} + +#[test] +fn exec_graph_unary_not() { + check_exec_graph( + "function Main() : Bool { not true }", + &expect![[r#" + 0: Expr(ExprId(4)) [Lit(Bool(true))] + 1: Expr(ExprId(3)) [UnOp(NotL)] + 2: Ret"#]], + ); +} + +#[test] +fn exec_graph_update_index_emits_store() { + check_exec_graph( + 
"function Main() : Int[] { mutable arr = [1, 2, 3]; set arr w/= 0 <- 42; arr }", + &expect![[r#" + 0: Expr(ExprId(3)) [ArrayLit(len=3)] + 1: Bind(PatId(1)) + 2: Expr(ExprId(8)) [Lit(Int(0))] + 3: Store + 4: Expr(ExprId(9)) [Lit(Int(42))] + 5: Expr(ExprId(7)) [AssignIndex] + 6: Unit + 7: Expr(ExprId(11)) [Var] + 8: Ret"#]], + ); +} + +#[test] +fn exec_graph_callable_with_adjoint_spec_rebuilds_body_and_adj_independently() { + let source = "operation Foo(q : Qubit) : Unit is Adj { body ... { H(q); } adjoint ... { X(q); } } operation Main() : Unit { use q = Qubit(); Foo(q); Adjoint Foo(q); }"; + check_callable_spec_exec_graph( + source, + "Foo", + CallableSpecKind::Body, + &expect![[r#" + 0: H + 1: Store + 2: Var(q) + 3: Call + 4: Unit + 5: Ret"#]], + ); + check_callable_spec_exec_graph( + source, + "Foo", + CallableSpecKind::Adj, + &expect![[r#" + 0: X + 1: Store + 2: Var(q) + 3: Call + 4: Unit + 5: Ret"#]], + ); +} + +#[test] +fn controlled_spec_exec_graph_rebuilds_semantic_order() { + check_callable_spec_exec_graph( + "operation Foo(q : Qubit) : Unit is Ctl { + body ... { X(q); } + controlled (ctls, ...) { Controlled X(ctls, q); } + } + operation Main() : Unit { + use ctl = Qubit(); + use q = Qubit(); + Controlled Foo([ctl], q); + }", + "Foo", + CallableSpecKind::Ctl, + &expect![[r#" + 0: X + 1: Functor(Ctl)(X) + 2: Store + 3: Var(ctls) + 4: Store + 5: Var(q) + 6: Store + 7: Tuple(len=2) + 8: Call + 9: Unit + 10: Ret"#]], + ); +} + +#[test] +fn controlled_adjoint_spec_exec_graph_rebuilds_semantic_order() { + check_callable_spec_exec_graph( + "operation Foo(q : Qubit) : Unit is Adj + Ctl { + body ... { S(q); } + adjoint ... { Adjoint S(q); } + controlled (ctls, ...) { Controlled S(ctls, q); } + controlled adjoint (ctls, ...) 
{ Controlled Adjoint S(ctls, q); } + } + operation Main() : Unit { + use ctl = Qubit(); + use q = Qubit(); + Controlled Adjoint Foo([ctl], q); + }", + "Foo", + CallableSpecKind::CtlAdj, + &expect![[r#" + 0: S + 1: Functor(Adj)(S) + 2: Functor(Ctl)(Functor(Adj)(S)) + 3: Store + 4: Var(ctls) + 5: Store + 6: Var(q) + 7: Store + 8: Tuple(len=2) + 9: Call + 10: Unit + 11: Ret"#]], + ); +} + +#[test] +fn simulatable_intrinsic_spec_exec_graph_rebuilds_semantic_order() { + check_callable_spec_exec_graph( + "@SimulatableIntrinsic() + operation MyMeasurement(q : Qubit) : Result { + H(q); + M(q) + } + @EntryPoint() + operation Main() : Result { + use q = Qubit(); + MyMeasurement(q) + }", + "MyMeasurement", + CallableSpecKind::SimulatableIntrinsic, + &expect![[r#" + 0: H + 1: Store + 2: Var(q) + 3: Call + 4: M + 5: Store + 6: Var(q) + 7: Call + 8: Ret"#]], + ); +} + +#[test] +fn exec_graph_entry_expression_rebuilt_correctly() { + check_exec_graph( + "function Main() : Int { let x = 1 + 2; let y = x * 3; y }", + &expect![[r#" + 0: Expr(ExprId(4)) [Lit(Int(1))] + 1: Store + 2: Expr(ExprId(5)) [Lit(Int(2))] + 3: Expr(ExprId(3)) [BinOp(Add)] + 4: Bind(PatId(1)) + 5: Expr(ExprId(7)) [Var] + 6: Store + 7: Expr(ExprId(8)) [Lit(Int(3))] + 8: Expr(ExprId(6)) [BinOp(Mul)] + 9: Bind(PatId(2)) + 10: Expr(ExprId(9)) [Var] + 11: Ret"#]], + ); +} + +#[test] +fn exec_graph_rebuild_is_idempotent() { + let (mut store, pkg_id) = compile_and_run_pipeline_to( + "function Main() : Int { let x = 1 + 2; x }", + PipelineStage::ExecGraphRebuild, + ); + let first = format_callable_exec_graph(&store, pkg_id, ExecGraphConfig::NoDebug); + + // Run rebuild a second time — the result must be identical. 
+ super::rebuild_exec_graphs(&mut store, pkg_id, &[]); + let second = format_callable_exec_graph(&store, pkg_id, ExecGraphConfig::NoDebug); + + assert_eq!(first, second, "exec graph rebuild is not idempotent"); +} + +#[test] +fn reachable_cross_package_callables_keep_existing_exec_graphs_while_local_specializations_rebuild() +{ + let source = r#" + open Std.Arrays; + open Std.Math; + + @EntryPoint() + operation Main() : Unit { + let arr = [-1, 2, -3]; + let _ = Mapped(AbsI, arr); + } + "#; + + let (mut store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ExecGraphRebuild); + + let local_specialization = + find_reachable_callable_by_name(&store, pkg_id, "Mapped{AbsI}", true); + let cross_package_callable = find_reachable_callable_by_name(&store, pkg_id, "AbsI", false); + + assert_eq!(local_specialization.package, pkg_id); + assert_ne!(cross_package_callable.package, pkg_id); + + let expected_local_graph = + format_store_callable_exec_graph(&store, local_specialization, ExecGraphConfig::NoDebug); + let expected_cross_graph = + format_store_callable_exec_graph(&store, cross_package_callable, ExecGraphConfig::NoDebug); + + assert!( + !expected_local_graph.is_empty(), + "local specialization should have a rebuilt exec graph" + ); + assert!( + !expected_cross_graph.is_empty(), + "reachable cross-package callable should start with a lowered exec graph" + ); + + clear_store_callable_exec_graph(&mut store, local_specialization); + clear_store_callable_exec_graph(&mut store, cross_package_callable); + + assert_eq!( + callable_body_exec_graph_len(&store, local_specialization), + 0 + ); + assert_eq!( + callable_body_exec_graph_len(&store, cross_package_callable), + 0 + ); + + super::rebuild_exec_graphs(&mut store, pkg_id, &[]); + + assert_eq!( + format_store_callable_exec_graph(&store, local_specialization, ExecGraphConfig::NoDebug), + expected_local_graph, + "reachable local specialization should be rebuilt" + ); + assert_eq!( + 
callable_body_exec_graph_len(&store, cross_package_callable), + 0, + "reachable cross-package callable should not be rebuilt" + ); +} + +#[test] +fn exec_graph_rebuild_preserves_invariants() { + let source = indoc! {" + namespace Test { + @EntryPoint() + operation Main() : Unit { + use q = Qubit(); + H(q); + Reset(q); + } + } + "}; + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ExecGraphRebuild); + crate::invariants::check(&store, pkg_id, crate::invariants::InvariantLevel::PostAll); +} + +#[test] +#[should_panic(expected = "Struct expressions should have been eliminated by udt_erase")] +fn exec_graph_rebuild_rejects_struct_expressions() { + // Feed FIR that still contains ExprKind::Struct (pipeline stopped + // before udt_erase) to exec_graph_rebuild. The pass should panic + // because struct expressions must be erased before exec graph rebuild. + let source = indoc! {" + namespace Test { + struct Pair { X : Int, Y : Int } + @EntryPoint() + function Main() : (Int, Int) { + let p = new Pair { X = 1, Y = 2 }; + (p.X, p.Y) + } + } + "}; + let (mut store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Defunc); + super::rebuild_exec_graphs(&mut store, pkg_id, &[]); +} + +#[test] +fn pinned_item_rebuilt_in_exec_graph() { + // After full pipeline with pinned items, verify the pinned callable has + // non-empty exec graph nodes — proving it participates in exec graph rebuild. + use crate::test_utils::compile_to_fir; + + let (mut store, pkg_id) = compile_to_fir(indoc! 
{" + namespace Test { + @EntryPoint() + operation Main() : Int { 42 } + // Unreachable from entry but will be pinned + operation Pinned() : Int { 99 } + } + "}); + let package = store.get(pkg_id); + let pinned_local = package + .items + .iter() + .find_map(|(item_id, item)| match &item.kind { + ItemKind::Callable(decl) if decl.name.name.as_ref() == "Pinned" => Some(item_id), + _ => None, + }) + .expect("Pinned callable should exist"); + let pinned_store_id = StoreItemId { + package: pkg_id, + item: pinned_local, + }; + + let errors = crate::run_pipeline_to( + &mut store, + pkg_id, + PipelineStage::ExecGraphRebuild, + &[pinned_store_id], + ); + assert!(errors.is_empty(), "pipeline errors: {errors:?}"); + + // Verify the pinned callable's spec has a non-empty exec graph. + let package = store.get(pkg_id); + let item = package.get_item(pinned_local); + if let ItemKind::Callable(decl) = &item.kind { + if let CallableImpl::Spec(spec) = &decl.implementation { + let graph = spec + .body + .exec_graph + .select_ref(qsc_fir::fir::ExecGraphConfig::NoDebug); + assert!( + !graph.is_empty(), + "pinned callable should have non-empty exec graph after rebuild" + ); + } else { + panic!("pinned callable should have Spec implementation"); + } + } else { + panic!("pinned item should be a callable"); + } +} diff --git a/source/compiler/qsc_fir_transforms/src/fir_builder.rs b/source/compiler/qsc_fir_transforms/src/fir_builder.rs new file mode 100644 index 0000000000..43f40d73d5 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/fir_builder.rs @@ -0,0 +1,433 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Shared FIR node allocation helpers. +//! +//! Every transform pass that synthesizes new FIR nodes must: +//! - Allocate a fresh ID from the pipeline-global [`Assigner`]. +//! - Insert the node into the package's arena. +//! - Attach [`EMPTY_EXEC_RANGE`](crate::EMPTY_EXEC_RANGE) for `Expr` and +//! 
`Stmt` nodes so the final [`exec_graph_rebuild`](crate::exec_graph_rebuild) +//! pass can replace them with correct ranges. +//! +//! This module provides composable helpers that encapsulate this pattern, +//! reducing boilerplate across passes and centralizing the +//! `EMPTY_EXEC_RANGE` convention. + +use crate::EMPTY_EXEC_RANGE; +use qsc_data_structures::span::Span; +use qsc_fir::assigner::Assigner; +use qsc_fir::fir::{ + BinOp, Block, BlockId, CallableDecl, Expr, ExprId, ExprKind, Field, FieldPath, Ident, ItemId, + ItemKind, LocalItemId, LocalVarId, Mutability, Package, PackageId, PackageLookup, PackageStore, + Pat, PatId, PatKind, Res, SpecDecl, SpecImpl, Stmt, StmtId, StmtKind, StoreItemId, UnOp, +}; +use rustc_hash::FxHashSet; + +use qsc_fir::ty::{Prim, Ty}; +use std::rc::Rc; + +/// Allocates an `Expr` with the given kind and inserts it into the package. +pub(crate) fn alloc_expr( + package: &mut Package, + assigner: &mut Assigner, + ty: Ty, + kind: ExprKind, + span: Span, +) -> ExprId { + let id = assigner.next_expr(); + package.exprs.insert( + id, + Expr { + id, + span, + ty, + kind, + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + id +} + +/// Allocates a `Var(Res::Local(var_id))` expression. +pub(crate) fn alloc_local_var_expr( + package: &mut Package, + assigner: &mut Assigner, + var_id: LocalVarId, + ty: Ty, + span: Span, +) -> ExprId { + alloc_expr( + package, + assigner, + ty, + ExprKind::Var(Res::Local(var_id), Vec::new()), + span, + ) +} + +/// Allocates a `Field(record, Path([index]))` expression. +pub(crate) fn alloc_field_expr( + package: &mut Package, + assigner: &mut Assigner, + record_id: ExprId, + index: usize, + ty: Ty, + span: Span, +) -> ExprId { + alloc_expr( + package, + assigner, + ty, + ExprKind::Field( + record_id, + Field::Path(FieldPath { + indices: vec![index], + }), + ), + span, + ) +} + +/// Allocates a `BinOp(op, lhs, rhs)` expression. 
+pub(crate) fn alloc_bin_op_expr( + package: &mut Package, + assigner: &mut Assigner, + op: BinOp, + lhs: ExprId, + rhs: ExprId, + ty: Ty, + span: Span, +) -> ExprId { + alloc_expr(package, assigner, ty, ExprKind::BinOp(op, lhs, rhs), span) +} + +/// Allocates a `UnOp(NotL, operand)` expression with `Bool` type. +pub(crate) fn alloc_not_expr( + package: &mut Package, + assigner: &mut Assigner, + operand: ExprId, + span: Span, +) -> ExprId { + alloc_expr( + package, + assigner, + Ty::Prim(Prim::Bool), + ExprKind::UnOp(UnOp::NotL, operand), + span, + ) +} + +/// Allocates an `If(cond, then, else)` expression. +pub(crate) fn alloc_if_expr( + package: &mut Package, + assigner: &mut Assigner, + cond: ExprId, + then_expr: ExprId, + else_expr: Option, + ty: Ty, + span: Span, +) -> ExprId { + alloc_expr( + package, + assigner, + ty, + ExprKind::If(cond, then_expr, else_expr), + span, + ) +} + +/// Allocates a `Block(block_id)` expression. +pub(crate) fn alloc_block_expr( + package: &mut Package, + assigner: &mut Assigner, + block_id: BlockId, + ty: Ty, + span: Span, +) -> ExprId { + alloc_expr(package, assigner, ty, ExprKind::Block(block_id), span) +} + +/// Allocates an `Assign(lhs, rhs)` expression with Unit type. +pub(crate) fn alloc_assign_expr( + package: &mut Package, + assigner: &mut Assigner, + lhs: ExprId, + rhs: ExprId, + span: Span, +) -> ExprId { + alloc_expr( + package, + assigner, + Ty::UNIT, + ExprKind::Assign(lhs, rhs), + span, + ) +} + +/// Allocates a boolean literal expression. +pub(crate) fn alloc_bool_lit( + package: &mut Package, + assigner: &mut Assigner, + value: bool, + span: Span, +) -> ExprId { + alloc_expr( + package, + assigner, + Ty::Prim(Prim::Bool), + ExprKind::Lit(qsc_fir::fir::Lit::Bool(value)), + span, + ) +} + +/// Allocates a Unit `()` expression. 
+pub(crate) fn alloc_unit_expr( + package: &mut Package, + assigner: &mut Assigner, + span: Span, +) -> ExprId { + alloc_expr( + package, + assigner, + Ty::UNIT, + ExprKind::Tuple(Vec::new()), + span, + ) +} + +/// Allocates a `Tuple(exprs)` expression. +#[allow(dead_code)] +pub(crate) fn alloc_tuple_expr( + package: &mut Package, + assigner: &mut Assigner, + exprs: Vec, + ty: Ty, + span: Span, +) -> ExprId { + alloc_expr(package, assigner, ty, ExprKind::Tuple(exprs), span) +} + +/// Allocates a `Stmt` with the given kind and inserts it into the package. +pub(crate) fn alloc_stmt( + package: &mut Package, + assigner: &mut Assigner, + kind: StmtKind, + span: Span, +) -> StmtId { + let id = assigner.next_stmt(); + package.stmts.insert( + id, + Stmt { + id, + span, + kind, + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + id +} + +/// Allocates an `Expr` statement (trailing expression, no semicolon). +pub(crate) fn alloc_expr_stmt( + package: &mut Package, + assigner: &mut Assigner, + expr_id: ExprId, + span: Span, +) -> StmtId { + alloc_stmt(package, assigner, StmtKind::Expr(expr_id), span) +} + +/// Allocates a `Semi` statement (expression with trailing semicolon). +pub(crate) fn alloc_semi_stmt( + package: &mut Package, + assigner: &mut Assigner, + expr_id: ExprId, + span: Span, +) -> StmtId { + alloc_stmt(package, assigner, StmtKind::Semi(expr_id), span) +} + +/// Allocates a `Local` statement (variable declaration). +pub(crate) fn alloc_local_stmt( + package: &mut Package, + assigner: &mut Assigner, + mutability: Mutability, + pat_id: PatId, + init_expr: ExprId, + span: Span, +) -> StmtId { + alloc_stmt( + package, + assigner, + StmtKind::Local(mutability, pat_id, init_expr), + span, + ) +} + +/// Allocates a `Block` and inserts it into the package. 
+pub(crate) fn alloc_block( + package: &mut Package, + assigner: &mut Assigner, + stmts: Vec, + ty: Ty, + span: Span, +) -> BlockId { + let id = assigner.next_block(); + package.blocks.insert( + id, + Block { + id, + span, + ty, + stmts, + }, + ); + id +} + +/// Allocates a `Pat` with `PatKind::Bind` and inserts it into the package. +pub(crate) fn alloc_bind_pat( + package: &mut Package, + assigner: &mut Assigner, + name: &str, + ty: Ty, + span: Span, +) -> (LocalVarId, PatId) { + let local_id = assigner.next_local(); + let pat_id = assigner.next_pat(); + package.pats.insert( + pat_id, + Pat { + id: pat_id, + span, + ty, + kind: PatKind::Bind(Ident { + id: local_id, + span, + name: Rc::from(name), + }), + }, + ); + (local_id, pat_id) +} + +/// Creates a local variable declaration and returns its `(LocalVarId, StmtId)`. +/// +/// Combines [`alloc_bind_pat`] + [`alloc_local_stmt`]. +pub(crate) fn alloc_local_var( + package: &mut Package, + assigner: &mut Assigner, + name: &str, + ty: &Ty, + init_expr: ExprId, + mutability: Mutability, +) -> (LocalVarId, StmtId) { + let (local_id, pat_id) = alloc_bind_pat(package, assigner, name, ty.clone(), Span::default()); + let stmt_id = alloc_local_stmt( + package, + assigner, + mutability, + pat_id, + init_expr, + Span::default(), + ); + (local_id, stmt_id) +} + +/// Resolves a `Ty::Udt(Res::Item(item_id))` to its constituent field types +/// via `get_pure_ty()`. Returns `None` for single-field UDTs or non-UDT items. +pub(crate) fn resolve_udt_element_types(store: &PackageStore, item_id: &ItemId) -> Option> { + let package = store.get(item_id.package); + let item = package.get_item(item_id.item); + if let ItemKind::Ty(_, udt) = &item.kind { + match udt.get_pure_ty() { + Ty::Tuple(elems) if !elems.is_empty() => Some(elems), + _ => None, + } + } else { + None + } +} + +/// Decomposes a `PatKind::Bind` pattern into a `PatKind::Tuple` of per-element +/// bindings. 
+/// +/// Allocates `n` new `LocalVarId`/`PatId` pairs (where `n = elem_types.len()`), +/// each named `{name}_{i}`, and rewrites the original pattern to +/// `PatKind::Tuple(new_pat_ids)`. +/// +/// Returns the newly allocated local variable IDs. +pub(crate) fn decompose_binding( + package: &mut Package, + assigner: &mut Assigner, + pat_id: PatId, + name: &str, + elem_types: &[Ty], +) -> Vec { + let n = elem_types.len(); + let mut new_locals: Vec = Vec::with_capacity(n); + let mut new_pat_ids: Vec = Vec::with_capacity(n); + + for (i, elem_ty) in elem_types.iter().enumerate() { + let new_local = assigner.next_local(); + new_locals.push(new_local); + + let new_pat_id = assigner.next_pat(); + let elem_name: Rc = Rc::from(format!("{name}_{i}")); + let new_pat = Pat { + id: new_pat_id, + span: Span::default(), + ty: elem_ty.clone(), + kind: PatKind::Bind(Ident { + id: new_local, + span: Span::default(), + name: elem_name, + }), + }; + package.pats.insert(new_pat_id, new_pat); + new_pat_ids.push(new_pat_id); + } + + // Rewrite the original binding pattern in-place. + let pat = package + .pats + .get_mut(pat_id) + .expect("candidate pat should exist"); + pat.kind = PatKind::Tuple(new_pat_ids); + + new_locals +} + +/// Returns an iterator-like collection of `(LocalItemId, &CallableDecl)` for +/// every reachable callable that belongs to the given package. +/// +/// Filters `reachable` to items in `package_id` that are `ItemKind::Callable`. +pub(crate) fn reachable_local_callables<'a>( + package: &'a Package, + package_id: PackageId, + reachable: &'a FxHashSet, +) -> impl Iterator { + reachable.iter().filter_map(move |item_id| { + if item_id.package != package_id { + return None; + } + let item = package.get_item(item_id.item); + match &item.kind { + ItemKind::Callable(decl) => Some((item_id.item, decl.as_ref())), + _ => None, + } + }) +} + +/// Returns an iterator over the functored specializations (`adj`, `ctl`, `ctl_adj`) +/// of a `SpecImpl`, skipping `None` entries. 
+pub(crate) fn functored_specs(spec_impl: &SpecImpl) -> impl Iterator { + [ + spec_impl.adj.as_ref(), + spec_impl.ctl.as_ref(), + spec_impl.ctl_adj.as_ref(), + ] + .into_iter() + .flatten() +} diff --git a/source/compiler/qsc_fir_transforms/src/gc_unreachable.rs b/source/compiler/qsc_fir_transforms/src/gc_unreachable.rs new file mode 100644 index 0000000000..e252bb0691 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/gc_unreachable.rs @@ -0,0 +1,177 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! FIR arena garbage collection. +//! +//! Removes unreachable (orphaned) blocks, stmts, exprs, and pats from +//! a package's [`IndexMap`](qsc_data_structures::index_map::IndexMap) arenas +//! by tombstoning entries that are not reachable from any callable spec body +//! or the package entry expression. +//! +//! # When to run +//! +//! After all FIR transforms that create/orphan arena nodes have completed +//! and before [`exec_graph_rebuild`](crate::exec_graph_rebuild) reconstructs +//! execution graphs from the surviving FIR tree. +//! +//! # Correctness contract +//! +//! The sweep phase tombstones complete unreachable subgraphs: if a node is +//! unreachable, all of its descendants are also unreachable (because the +//! only paths to descendants go through ancestors). The mark phase records +//! every node it visits via the [`Visitor`] trait. The combination guarantees +//! that no surviving node references a tombstoned node, so +//! [`PackageLookup::get_*(..)`](qsc_fir::fir::PackageLookup) calls remain +//! safe. +//! +//! # Transformation shape +//! +//! **Before:** Package arenas contain orphaned blocks, stmts, exprs, and pats +//! left behind by earlier rewrite passes (return unify, defunctionalize, UDT +//! erase, SROA, argument promote). +//! +//! **After:** Only nodes reachable from callable bodies and the entry +//! expression survive. Orphaned entries are tombstoned in the `IndexMap`. 
+ +#[cfg(test)] +mod tests; + +use qsc_fir::fir::{ + Block, BlockId, Expr, ExprId, Package, PackageLookup, Pat, PatId, Stmt, StmtId, +}; +use qsc_fir::visit::{self, Visitor}; +use rustc_hash::FxHashSet; + +/// Tombstones unreachable blocks, stmts, exprs, and pats in the package's +/// `IndexMap` arenas. Returns the total number of entries removed. +/// +/// "Unreachable" means: not visited by a [`Visitor`] walk starting from +/// every item in `package.items` and the `package.entry` expression. +/// Items themselves are never removed. +/// +/// # When to call +/// +/// After all FIR transforms that create or orphan arena nodes, and before +/// `exec_graph_rebuild`. +pub fn gc_unreachable(package: &mut Package) -> usize { + let live = mark(package); + sweep(package, &live) +} + +/// Reachable-ID sets for each arena type. +#[derive(Debug, Default)] +struct LiveSets { + blocks: FxHashSet, + stmts: FxHashSet, + exprs: FxHashSet, + pats: FxHashSet, +} + +fn mark(package: &Package) -> LiveSets { + let mut collector = ReachabilityCollector { + package, + live: LiveSets::default(), + }; + + // Walk all items (callable spec bodies, including unreachable callables — + // item-level DCE is a separate concern). This ensures every spec body's + // nodes are marked live. + for (_, item) in &package.items { + collector.visit_item(item); + } + + // Walk the entry expression tree (may reference nodes not reachable from + // any callable spec body, e.g. top-level let bindings in the entry block). 
+ if let Some(entry_expr_id) = package.entry { + collector.visit_expr(entry_expr_id); + } + + collector.live +} + +struct ReachabilityCollector<'a> { + package: &'a Package, + live: LiveSets, +} + +impl<'a> Visitor<'a> for ReachabilityCollector<'a> { + fn get_block(&self, id: BlockId) -> &'a Block { + self.package.get_block(id) + } + + fn get_expr(&self, id: ExprId) -> &'a Expr { + self.package.get_expr(id) + } + + fn get_pat(&self, id: PatId) -> &'a Pat { + self.package.get_pat(id) + } + + fn get_stmt(&self, id: StmtId) -> &'a Stmt { + self.package.get_stmt(id) + } + + fn visit_block(&mut self, id: BlockId) { + if self.live.blocks.insert(id) { + visit::walk_block(self, id); + } + } + + fn visit_stmt(&mut self, id: StmtId) { + if self.live.stmts.insert(id) { + visit::walk_stmt(self, id); + } + } + + fn visit_expr(&mut self, id: ExprId) { + if self.live.exprs.insert(id) { + visit::walk_expr(self, id); + } + } + + fn visit_pat(&mut self, id: PatId) { + if self.live.pats.insert(id) { + visit::walk_pat(self, id); + } + } +} + +/// Deletes every arena node that was not marked live during `mark`. +/// +/// Before, dead blocks, statements, expressions, and patterns still occupy the +/// package arenas and can keep stale ids addressable. After, only the nodes in +/// `live` remain and the returned count records how many entries were purged. 
+fn sweep(package: &mut Package, live: &LiveSets) -> usize { + let mut removed = 0; + + package.blocks.retain(|id, _| { + let keep = live.blocks.contains(&id); + if !keep { + removed += 1; + } + keep + }); + package.stmts.retain(|id, _| { + let keep = live.stmts.contains(&id); + if !keep { + removed += 1; + } + keep + }); + package.exprs.retain(|id, _| { + let keep = live.exprs.contains(&id); + if !keep { + removed += 1; + } + keep + }); + package.pats.retain(|id, _| { + let keep = live.pats.contains(&id); + if !keep { + removed += 1; + } + keep + }); + + removed +} diff --git a/source/compiler/qsc_fir_transforms/src/gc_unreachable/tests.rs b/source/compiler/qsc_fir_transforms/src/gc_unreachable/tests.rs new file mode 100644 index 0000000000..101de097e4 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/gc_unreachable/tests.rs @@ -0,0 +1,224 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +// Proptest applicability: Low — gc_unreachable operates on FIR arena nodes (mark-and-sweep), +// not on Q# semantics. Its correctness is a structural invariant (no surviving node references +// a tombstoned node) rather than behavioral equivalence. Q# template generation doesn't add +// much beyond targeted snapshots that create known orphan patterns. + +use crate::PipelineStage; +use crate::test_utils::compile_and_run_pipeline_to; +use expect_test::{Expect, expect}; +use indoc::indoc; + +/// Counts total live entries across all four arena types. +fn arena_live_count(package: &qsc_fir::fir::Package) -> usize { + package.blocks.iter().count() + + package.stmts.iter().count() + + package.exprs.iter().count() + + package.pats.iter().count() +} + +#[test] +fn gc_no_orphans_preserves_all_entries() { + // A simple program with one operation, no closures, no multiple returns. + // After arg_promote, there should be no orphans. + let source = indoc! 
{" + namespace Test { + @EntryPoint() + operation Main() : Unit { + use q = Qubit(); + H(q); + Reset(q); + } + } + "}; + let (mut store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ArgPromote); + let before = arena_live_count(store.get(pkg_id)); + let removed = super::gc_unreachable(store.get_mut(pkg_id)); + let after = arena_live_count(store.get(pkg_id)); + assert_eq!(removed, 0, "simple program should have no orphans"); + assert_eq!(before, after, "arena sizes should be unchanged"); +} + +#[test] +fn gc_removes_return_unify_orphans() { + // A program with multiple return paths triggers return_unify rewrites, + // which leaves the original return-path stmts/exprs as orphans. + let source = indoc! {" + namespace Test { + @EntryPoint() + function Main() : Int { + if true { + return 1; + } + return 2; + } + } + "}; + let (mut store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ArgPromote); + let removed = super::gc_unreachable(store.get_mut(pkg_id)); + assert!( + removed > 0, + "return_unify should leave orphans that GC removes" + ); + // Verify post-GC integrity (PostArgPromote: checks arena links without + // requiring exec_graph_rebuild to have run). + crate::invariants::check( + &store, + pkg_id, + crate::invariants::InvariantLevel::PostArgPromote, + ); +} + +#[test] +fn gc_removes_defunc_orphans() { + // A program with closures triggers defunctionalization body cloning, + // which leaves original closure bodies as orphans. + let source = indoc! 
{" + namespace Test { + function Apply(f : Int -> Int, x : Int) : Int { f(x) } + @EntryPoint() + function Main() : Int { Apply(x -> x + 1, 5) } + } + "}; + let (mut store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ArgPromote); + let removed = super::gc_unreachable(store.get_mut(pkg_id)); + assert!(removed > 0, "defunc should leave orphans that GC removes"); + // Verify post-GC integrity (PostArgPromote: checks arena links without + // requiring exec_graph_rebuild to have run). + crate::invariants::check( + &store, + pkg_id, + crate::invariants::InvariantLevel::PostArgPromote, + ); +} + +#[test] +fn gc_then_check_id_references_passes() { + // A non-trivial program exercising multiple transform passes. + // After GC, check_id_references (via PostAll invariants) should not panic. + let source = indoc! {" + namespace Test { + operation ApplyIfOne(q : Qubit, op : Qubit => Unit) : Unit { + op(q); + } + @EntryPoint() + operation Main() : Unit { + use q = Qubit(); + ApplyIfOne(q, H); + if M(q) == One { + X(q); + } + Reset(q); + } + } + "}; + // Run full pipeline — this runs GC then PostAll invariants (including check_id_references). + let (_store, _pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Full); + // If we reach here, check_id_references passed post-GC. +} + +#[test] +fn gc_on_entry_less_package_is_noop() { + // Compile a source with entry, then target the core package (no entry). + let source = indoc! 
{" + namespace Test { + @EntryPoint() + operation Main() : Unit {} + } + "}; + let (mut store, _pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ArgPromote); + let core_id = qsc_fir::fir::PackageId::CORE; + assert!( + store.get(core_id).entry.is_none(), + "core package should have no entry expression" + ); + let removed = super::gc_unreachable(store.get_mut(core_id)); + assert_eq!(removed, 0, "entry-less core package should have no orphans"); +} + +#[test] +fn gc_is_idempotent() { + // Multiple return paths leave orphaned arena nodes after return_unify. + let source = indoc! {" + namespace Test { + @EntryPoint() + function Main() : Int { + if true { + return 1; + } + return 2; + } + } + "}; + let (mut store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ArgPromote); + let first_pass = super::gc_unreachable(store.get_mut(pkg_id)); + assert!(first_pass > 0, "first GC pass should remove orphans"); + let second_pass = super::gc_unreachable(store.get_mut(pkg_id)); + assert_eq!( + second_pass, 0, + "second GC pass should find nothing to remove" + ); +} + +fn render_before_after_gc(source: &str) -> (String, String) { + let (mut store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ArgPromote); + let before = crate::pretty::write_package_qsharp(&store, pkg_id); + super::gc_unreachable(store.get_mut(pkg_id)); + let after = crate::pretty::write_package_qsharp(&store, pkg_id); + (before, after) +} + +fn check_before_after_gc(source: &str, expect: &Expect) { + let (before, after) = render_before_after_gc(source); + expect.assert_eq(&format!("BEFORE:\n{before}\nAFTER:\n{after}")); +} + +#[test] +fn before_after_gc_removes_orphans() { + check_before_after_gc( + indoc! 
{" + namespace Test { + @EntryPoint() + function Main() : Int { + if true { + return 1; + } + return 2; + } + } + "}, + &expect![[r#" + BEFORE: + // namespace Test + function Main() : Int { + body { + if true { + 1 + } else { + 2 + } + + } + } + // entry + Main() + + AFTER: + // namespace Test + function Main() : Int { + body { + if true { + 1 + } else { + 2 + } + + } + } + // entry + Main() + "#]], // snapshot populated by UPDATE_EXPECT=1 + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/invariants.rs b/source/compiler/qsc_fir_transforms/src/invariants.rs new file mode 100644 index 0000000000..a74a671dc6 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/invariants.rs @@ -0,0 +1,1591 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! FIR structural invariant checker. +//! +//! Verifies that the FIR is well-formed after each transformation pass. +//! Different invariant levels check progressively stronger properties as more +//! passes have been applied. +//! +//! [`InvariantLevel`] variants correspond to pipeline stages in order: +//! +//! | Variant | Checked after | +//! |---|---| +//! | `PostMono` | Monomorphization — no `Ty::Param` in reachable code. | +//! | `PostReturnUnify` | Return unification — no `ExprKind::Return`. | +//! | `PostDefunc` | Defunctionalization — no `Ty::Arrow` / closures. | +//! | `PostUdtErase` | UDT erasure — no `Ty::Udt` / struct exprs. | +//! | `PostTupleCompLower` | Tuple comparison lowering. | +//! | `PostSroa` | SROA — tuple decomposition patterns match types. | +//! | `PostArgPromote` | Argument promotion — input patterns match types. | +//! | `PostGc` | Unreachable GC — no orphaned arena node references. | +//! | `PostAll` | All passes — full structural + type checks. | +//! 
+#[cfg(test)] +mod tests; + +use crate::fir_builder::functored_specs; +use qsc_fir::fir::{ + BinOp, BlockId, CallableDecl, CallableImpl, ExecGraphConfig, ExecGraphDebugNode, ExecGraphNode, + ExprId, ExprKind, Field, ItemKind, LocalItemId, LocalVarId, Package, PackageId, PackageLookup, + PackageStore, PatId, PatKind, Res, SpecDecl, StmtKind, StoreItemId, +}; +use qsc_fir::ty::{FunctorSet, Ty}; +use rustc_hash::{FxHashMap, FxHashSet}; + +use crate::reachability::{collect_reachable_from_entry, collect_reachable_package_closure}; + +/// The level of invariant checking to perform, corresponding to which passes +/// have already been applied. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum InvariantLevel { + /// After monomorphization: no `Ty::Param` in reachable code. + PostMono, + /// After return unification: additionally no `ExprKind::Return` in reachable code. + PostReturnUnify, + /// After defunctionalization: additionally no `Ty::Arrow` params and no + /// `ExprKind::Closure` in reachable code. + PostDefunc, + /// After UDT erasure: additionally no `Ty::Udt`, no + /// `ExprKind::Struct`, and no `Field::Path` in `UpdateField`/`AssignField`. + PostUdtErase, + /// After tuple comparison lowering: additionally no `BinOp(Eq/Neq)` on + /// tuple-typed operands. + PostTupleCompLower, + /// After SROA: additionally synthesized local tuple patterns must match + /// the tuple types they decompose. + PostSroa, + /// After argument promotion: additionally synthesized callable input tuple + /// patterns must match the callable input types they decompose. + PostArgPromote, + /// After unreachable GC: no orphaned arena node references survive in the + /// live FIR tree. Inherits all [`PostArgPromote`](Self::PostArgPromote) + /// checks. + PostGc, + /// After all passes: all structural checks plus per-pass type constraints. + PostAll, +} + +impl InvariantLevel { + /// Returns `true` when this level is at or after monomorphization. 
+ fn is_post_mono_or_later(self) -> bool { + matches!( + self, + Self::PostMono + | Self::PostReturnUnify + | Self::PostDefunc + | Self::PostUdtErase + | Self::PostTupleCompLower + | Self::PostSroa + | Self::PostArgPromote + | Self::PostGc + | Self::PostAll + ) + } + + /// Returns `true` when this level is at or after return unification. + fn is_post_return_unify_or_later(self) -> bool { + matches!( + self, + Self::PostReturnUnify + | Self::PostDefunc + | Self::PostUdtErase + | Self::PostTupleCompLower + | Self::PostSroa + | Self::PostArgPromote + | Self::PostGc + | Self::PostAll + ) + } + + /// Returns `true` when this level is at or after defunctionalization. + fn is_post_defunc_or_later(self) -> bool { + matches!( + self, + Self::PostDefunc + | Self::PostUdtErase + | Self::PostTupleCompLower + | Self::PostSroa + | Self::PostArgPromote + | Self::PostGc + | Self::PostAll + ) + } + + /// Returns `true` when this level is at or after UDT erasure. + fn is_post_udt_erase_or_later(self) -> bool { + matches!( + self, + Self::PostUdtErase + | Self::PostTupleCompLower + | Self::PostSroa + | Self::PostArgPromote + | Self::PostGc + | Self::PostAll + ) + } + + /// Returns `true` when this level is at or after tuple comparison lowering. + fn is_post_tuple_comp_lower_or_later(self) -> bool { + matches!( + self, + Self::PostTupleCompLower + | Self::PostSroa + | Self::PostArgPromote + | Self::PostGc + | Self::PostAll + ) + } + + /// Returns `true` when this level is at or after SROA. + fn is_post_sroa_or_later(self) -> bool { + matches!( + self, + Self::PostSroa | Self::PostArgPromote | Self::PostGc | Self::PostAll + ) + } + + /// Returns `true` when this level is at or after argument promotion. + fn is_post_arg_promote_or_later(self) -> bool { + matches!(self, Self::PostArgPromote | Self::PostGc | Self::PostAll) + } +} + +/// Checks FIR structural invariants on entry-reachable code. 
+/// +/// The invariant walk is scoped to items reachable from the target package's +/// entry expression. Items pinned for backend codegen (e.g. for +/// `fir_to_qir_from_callable`) are excluded from this check — the production +/// pipeline intentionally limits invariant enforcement to the entry-rooted +/// reachability closure. +/// +/// # Panics +/// +/// Panics with a descriptive message if any invariant is violated. +pub fn check(store: &PackageStore, package_id: qsc_fir::fir::PackageId, level: InvariantLevel) { + let package = store.get(package_id); + check_id_references(package); + + let Some(entry_id) = package.entry else { + return; + }; + + let reachable = collect_reachable_from_entry(store, package_id); + if level.is_post_udt_erase_or_later() { + let reachable_packages = collect_reachable_package_closure(package_id, &reachable); + for reachable_package_id in reachable_packages { + let reachable_package = store.get(reachable_package_id); + if reachable_package_id != package_id { + check_id_references(reachable_package); + } + check_package_udt_erase_invariants(reachable_package); + } + } + + check_reachable_invariants(store, package_id, &reachable, level); + + if level.is_post_defunc_or_later() { + check_expr_id_ownership(store, package_id, &reachable, entry_id); + } + + if level.is_post_return_unify_or_later() { + check_non_unit_block_tails(store, package_id, &reachable); + } + + // Check type invariants on the entry expression tree. + check_expr_types(store, package, entry_id, level); + + // After all passes, validate the entry exec graph. + if level == InvariantLevel::PostAll { + for (config, label) in [ + (ExecGraphConfig::NoDebug, "no_debug"), + (ExecGraphConfig::Debug, "debug"), + ] { + let nodes = package.entry_exec_graph.select_ref(config); + check_configured_exec_graph(package, nodes, "entry_exec_graph", label); + } + } +} + +/// Validates the package-wide surfaces that `udt_erase` mutates. 
+/// +/// The pass rewrites expression types and kinds, pattern types, block types, +/// and callable output types across every package in the reachable package +/// closure. This checker mirrors that mutation boundary without applying the +/// stronger target-package-only assumptions from later passes. +fn check_package_udt_erase_invariants(package: &Package) { + for (expr_id, _expr) in &package.exprs { + check_expr_udt_erase_invariants(package, expr_id); + } + + for (pat_id, pat) in &package.pats { + check_type_udt_erase_invariants(&pat.ty, &format!("Pat {pat_id}")); + } + + for (block_id, block) in &package.blocks { + check_type_udt_erase_invariants(&block.ty, &format!("Block {block_id}")); + } + + for (item_id, item) in &package.items { + if let ItemKind::Callable(decl) = &item.kind { + check_type_udt_erase_invariants(&decl.output, &format!("Callable {item_id} output")); + } + } +} + +/// Validates that a single expression satisfies post-UDT-erasure invariants: +/// no `Ty::Udt` in its type, no `ExprKind::Struct`, no `Field::Path` in +/// `UpdateField`/`AssignField`, and `Field::Path` only on tuple-typed records. +/// +/// # Panics +/// +/// Panics with a descriptive message if any UDT-erasure invariant is violated. 
+fn check_expr_udt_erase_invariants(package: &Package, expr_id: ExprId) { + let expr = package.get_expr(expr_id); + check_type_udt_erase_invariants(&expr.ty, &format!("Expr {expr_id}")); + + if matches!(&expr.kind, ExprKind::Struct(_, _, _)) { + panic!( + "PostUdtErase invariant violation: Expr {expr_id} contains \ + ExprKind::Struct after UDT erasure" + ); + } + + if let ExprKind::UpdateField(_, Field::Path(_), _) + | ExprKind::AssignField(_, Field::Path(_), _) = &expr.kind + { + panic!( + "PostUdtErase invariant violation: Expr {expr_id} contains \ + Field::Path in UpdateField/AssignField after UDT erasure" + ); + } + + if let ExprKind::Field(record_id, Field::Path(_)) = &expr.kind { + let record = package.get_expr(*record_id); + assert!( + matches!(&record.ty, Ty::Tuple(_)), + "PostUdtErase invariant violation: Expr {expr_id} has Field::Path \ + on non-tuple record Expr {record_id} (type: {:?})", + record.ty, + ); + } +} + +/// Recursively validates that a type contains no `Ty::Udt` variants. +/// +/// # Panics +/// +/// Panics if `Ty::Udt` is found anywhere within the type tree. +fn check_type_udt_erase_invariants(ty: &Ty, context: &str) { + match ty { + Ty::Array(inner) => check_type_udt_erase_invariants(inner, context), + Ty::Tuple(items) => { + for item in items { + check_type_udt_erase_invariants(item, context); + } + } + Ty::Arrow(arrow) => { + check_type_udt_erase_invariants(&arrow.input, context); + check_type_udt_erase_invariants(&arrow.output, context); + } + Ty::Udt(_) => { + panic!("{context} contains Ty::Udt after UDT erasure"); + } + Ty::Prim(_) | Ty::Param(_) | Ty::Infer(_) | Ty::Err => {} + } +} + +/// Verifies that every reachable non-Unit callable body block and nested block +/// expression ends in a trailing expression whose type matches the block type. 
+///
+/// This dispatcher fans out to `check_callable_non_unit_block_tails` for every
+/// reachable callable, then runs `check_nested_block_expr_tails` on the entry
+/// expression so nested block expressions outside callable bodies are covered
+/// too.
+///
+/// This invariant is only valid after return unification has collapsed terminal
+/// wrappers and for all later pipeline checkpoints.
+///
+/// # Panics
+///
+/// Panics with a descriptive message if any non-Unit block lacks a matching
+/// trailing `StmtKind::Expr`.
+pub(crate) fn check_non_unit_block_tails(
+    store: &PackageStore,
+    package_id: qsc_fir::fir::PackageId,
+    reachable: &FxHashSet<StoreItemId>,
+) {
+    let package = store.get(package_id);
+    let Some(entry_id) = package.entry else {
+        return;
+    };
+
+    for item_id in reachable {
+        if item_id.package != package_id {
+            continue;
+        }
+
+        let item_pkg = store.get(item_id.package);
+        let item = item_pkg.get_item(item_id.item);
+        if let ItemKind::Callable(decl) = &item.kind {
+            check_callable_non_unit_block_tails(item_pkg, decl);
+        }
+    }
+
+    check_nested_block_expr_tails(package, entry_id, "entry expression");
+}
+
+/// Checks the root blocks for a callable body and each explicit specialization,
+/// then re-walks the callable implementation to validate every nested block
+/// expression through `check_non_unit_block_tail`.
+fn check_callable_non_unit_block_tails(package: &Package, decl: &CallableDecl) { + let callable_name = decl.name.name.to_string(); + + match &decl.implementation { + CallableImpl::Spec(spec_impl) => { + check_spec_block_tail( + package, + &spec_impl.body, + &format!("callable '{callable_name}' body"), + ); + + for (label, spec) in [ + ("adj", &spec_impl.adj), + ("ctl", &spec_impl.ctl), + ("ctl_adj", &spec_impl.ctl_adj), + ] { + if let Some(spec) = spec { + check_spec_block_tail( + package, + spec, + &format!("callable '{callable_name}' {label}"), + ); + } + } + } + CallableImpl::SimulatableIntrinsic(spec) => { + check_spec_block_tail( + package, + spec, + &format!("callable '{callable_name}' simulatable intrinsic"), + ); + } + CallableImpl::Intrinsic => {} + } + + crate::walk_utils::for_each_expr_in_callable_impl( + package, + &decl.implementation, + &mut |expr_id, expr| { + if let ExprKind::Block(block_id) = &expr.kind { + check_non_unit_block_tail( + package, + *block_id, + &format!("callable '{callable_name}' Expr {expr_id}"), + ); + } + }, + ); +} + +/// Small adapter that routes a specialization root block into the general +/// non-Unit tail checker. +fn check_spec_block_tail(package: &Package, spec: &SpecDecl, context: &str) { + check_non_unit_block_tail(package, spec.block, context); +} + +/// Walks an expression tree and applies `check_non_unit_block_tail` to every +/// nested `ExprKind::Block` it finds. +fn check_nested_block_expr_tails(package: &Package, expr_id: ExprId, context: &str) { + crate::walk_utils::for_each_expr(package, expr_id, &mut |nested_expr_id, expr| { + if let ExprKind::Block(block_id) = &expr.kind { + check_non_unit_block_tail( + package, + *block_id, + &format!("{context} Expr {nested_expr_id}"), + ); + } + }); +} + +/// Validates the trailing statement shape for a single non-Unit block. 
+/// +/// This is the leaf helper used by the higher-level non-Unit block-tail +/// walkers once they have identified a specific block that should already be +/// in single-exit form. +/// +/// # Panics +/// +/// Panics if the block has a non-Unit type but is empty, ends in a non-Expr +/// statement, or ends in an expression whose type does not match the block +/// type. +fn check_non_unit_block_tail(package: &Package, block_id: BlockId, context: &str) { + let block = package.get_block(block_id); + if block.ty == Ty::UNIT { + return; + } + + let Some(&stmt_id) = block.stmts.last() else { + panic!( + "Non-Unit block-tail invariant violation: {context} Block {block_id} has type {:?} but has no trailing statement", + block.ty, + ); + }; + + let stmt = package.get_stmt(stmt_id); + let expr_id = match &stmt.kind { + StmtKind::Expr(expr_id) => *expr_id, + StmtKind::Semi(expr_id) => { + panic!( + "Non-Unit block-tail invariant violation: {context} Block {block_id} has type {:?} but ends with Semi Expr {expr_id}", + block.ty, + ); + } + StmtKind::Local(..) => { + panic!( + "Non-Unit block-tail invariant violation: {context} Block {block_id} has type {:?} but ends with a Local statement", + block.ty, + ); + } + StmtKind::Item(_) => { + panic!( + "Non-Unit block-tail invariant violation: {context} Block {block_id} has type {:?} but ends with an Item statement", + block.ty, + ); + } + }; + + let expr_ty = &package.get_expr(expr_id).ty; + assert!( + expr_ty == &block.ty, + "Non-Unit block-tail invariant violation: {context} Block {block_id} has type {:?} but trailing Expr {expr_id} has type {expr_ty:?}", + block.ty, + ); +} + +/// Verifies that all IDs referenced inside blocks, stmts, exprs, and pats +/// actually exist in their respective `IndexMap`s. 
+fn check_id_references(package: &Package) { + for (block_id, block) in &package.blocks { + assert_eq!( + block.id, block_id, + "Block {block_id} has mismatched id field" + ); + for &stmt_id in &block.stmts { + assert!( + package.stmts.get(stmt_id).is_some(), + "Block {block_id} references nonexistent Stmt {stmt_id}" + ); + } + } + + for (stmt_id, stmt) in &package.stmts { + assert_eq!(stmt.id, stmt_id, "Stmt {stmt_id} has mismatched id field"); + match &stmt.kind { + StmtKind::Expr(e) | StmtKind::Semi(e) => { + assert!( + package.exprs.get(*e).is_some(), + "Stmt {stmt_id} references nonexistent Expr {e}" + ); + } + StmtKind::Local(_, pat, expr) => { + assert!( + package.pats.get(*pat).is_some(), + "Stmt {stmt_id} references nonexistent Pat {pat}" + ); + assert!( + package.exprs.get(*expr).is_some(), + "Stmt {stmt_id} references nonexistent Expr {expr}" + ); + } + StmtKind::Item(_) => { + // After item DCE, `StmtKind::Item` stmts may reference + // items that were removed. This is benign: the exec graph + // never executes through item-definition stmts. + } + } + } + + for (expr_id, expr) in &package.exprs { + assert_eq!(expr.id, expr_id, "Expr {expr_id} has mismatched id field"); + check_expr_sub_ids(package, expr_id, &expr.kind); + } +} + +/// Checks that every child ID referenced by an expression kind exists in the +/// corresponding package map. +/// +/// `check_id_references` delegates expression-specific validation here after it +/// has confirmed the top-level expression record itself is present. +/// +/// # Panics +/// +/// Panics if any sub-expression or block ID referenced by `kind` is missing. 
+fn check_expr_sub_ids(package: &Package, parent_expr: ExprId, kind: &ExprKind) { + let assert_expr = |e: ExprId| { + assert!( + package.exprs.get(e).is_some(), + "Expr {parent_expr} references nonexistent sub-Expr {e}" + ); + }; + let assert_block = |b: BlockId| { + assert!( + package.blocks.get(b).is_some(), + "Expr {parent_expr} references nonexistent Block {b}" + ); + }; + + match kind { + ExprKind::Array(es) | ExprKind::ArrayLit(es) | ExprKind::Tuple(es) => { + for &e in es { + assert_expr(e); + } + } + ExprKind::ArrayRepeat(a, b) + | ExprKind::Assign(a, b) + | ExprKind::AssignOp(_, a, b) + | ExprKind::BinOp(_, a, b) + | ExprKind::Call(a, b) + | ExprKind::Index(a, b) + | ExprKind::AssignField(a, _, b) + | ExprKind::UpdateField(a, _, b) => { + assert_expr(*a); + assert_expr(*b); + } + ExprKind::AssignIndex(a, b, c) | ExprKind::UpdateIndex(a, b, c) => { + assert_expr(*a); + assert_expr(*b); + assert_expr(*c); + } + ExprKind::Block(block_id) => assert_block(*block_id), + ExprKind::Fail(e) | ExprKind::Field(e, _) | ExprKind::Return(e) | ExprKind::UnOp(_, e) => { + assert_expr(*e); + } + ExprKind::If(cond, body, otherwise) => { + assert_expr(*cond); + assert_expr(*body); + if let Some(e) = otherwise { + assert_expr(*e); + } + } + ExprKind::Range(s, st, e) => { + if let Some(x) = s { + assert_expr(*x); + } + if let Some(x) = st { + assert_expr(*x); + } + if let Some(x) = e { + assert_expr(*x); + } + } + ExprKind::Struct(_, copy, fields) => { + if let Some(c) = copy { + assert_expr(*c); + } + for fa in fields { + assert_expr(fa.value); + } + } + ExprKind::String(components) => { + for c in components { + if let qsc_fir::fir::StringComponent::Expr(e) = c { + assert_expr(*e); + } + } + } + ExprKind::While(cond, block) => { + assert_expr(*cond); + assert_block(*block); + } + ExprKind::Closure(_, _) | ExprKind::Hole | ExprKind::Lit(_) | ExprKind::Var(_, _) => {} + } +} + +/// Applies stage-gated callable checks to each reachable callable in the +/// target package. 
+///
+/// Depending on `level`, this dispatcher invokes:
+/// - `check_type_invariants` on callable output types.
+/// - `check_no_arrow_params` once defunctionalization should have removed
+///   callable-valued parameters. Pinned items are excluded from this check
+///   because they are specialization targets that intentionally retain
+///   arrow-typed parameters for callable-args codegen.
+/// - `check_callable_input_pattern_shapes` once SROA and argument promotion may
+///   have synthesized tuple-shaped inputs.
+/// - `check_no_returns` once return unification should have removed
+///   `ExprKind::Return`.
+/// - `check_spec_decl_types` on the body and explicit specializations.
+/// - `check_local_var_consistency` to ensure every local reference is still
+///   backed by a binder.
+/// - `check_spec_exec_graph` once exec graphs have been rebuilt at `PostAll`.
+fn check_reachable_invariants(
+    store: &PackageStore,
+    target_package_id: qsc_fir::fir::PackageId,
+    reachable: &FxHashSet<StoreItemId>,
+    level: InvariantLevel,
+) {
+    for item_id in reachable {
+        // Only check invariants on items in the target package. Cross-package
+        // items (e.g. stdlib) are not transformed by the surrounding stages
+        // and may still contain Ty::Param, Arrow types, or closures. Their
+        // package-wide UDT-erasure invariants are checked separately.
+        if item_id.package != target_package_id {
+            continue;
+        }
+        let item_pkg = store.get(item_id.package);
+        let item = item_pkg.get_item(item_id.item);
+        if let ItemKind::Callable(decl) = &item.kind {
+            // All reachable callables have been through the full pipeline
+            // via the entry expression and should pass all stage-specific
+            // invariant checks.
+ check_type_invariants(&decl.output, level, "callable output type"); + + if level.is_post_defunc_or_later() { + check_no_arrow_params(item_pkg, decl); + } + + if level.is_post_arg_promote_or_later() { + check_callable_input_pattern_shapes(item_pkg, decl); + } + + if level.is_post_return_unify_or_later() { + check_no_returns(item_pkg, decl); + } + + match &decl.implementation { + CallableImpl::Spec(spec_impl) => { + check_spec_decl_types(store, item_pkg, &spec_impl.body, level); + for spec in functored_specs(spec_impl) { + check_spec_decl_types(store, item_pkg, spec, level); + } + } + CallableImpl::SimulatableIntrinsic(spec) => { + check_spec_decl_types(store, item_pkg, spec, level); + } + CallableImpl::Intrinsic => {} + } + + if level.is_post_mono_or_later() { + check_local_var_consistency(item_pkg, decl); + } + + // After all passes, validate exec graph structural integrity. + if level == InvariantLevel::PostAll { + let name = &decl.name.name; + match &decl.implementation { + CallableImpl::Spec(spec_impl) => { + check_spec_exec_graph(item_pkg, &spec_impl.body, &format!("{name}/body")); + for (label, spec) in [ + ("adj", &spec_impl.adj), + ("ctl", &spec_impl.ctl), + ("ctl_adj", &spec_impl.ctl_adj), + ] { + if let Some(s) = spec { + check_spec_exec_graph(item_pkg, s, &format!("{name}/{label}")); + } + } + } + CallableImpl::SimulatableIntrinsic(spec) => { + check_spec_exec_graph(item_pkg, spec, &format!("{name}/sim_intrinsic")); + } + CallableImpl::Intrinsic => {} + } + } + } + } +} + +/// Validates that callable input patterns no longer expose arrow-typed leaves. +/// +/// The actual recursion lives in `check_pat_for_arrow` so tuple-shaped inputs +/// are checked all the way down to their leaves. +fn check_no_arrow_params(package: &Package, callable: &qsc_fir::fir::CallableDecl) { + check_pat_for_arrow(package, callable.input); +} + +/// Verifies that no `ExprKind::Return` nodes remain in a callable's body. 
+/// +/// # Panics +/// +/// Panics if any return expression is found. +fn check_no_returns(package: &Package, decl: &CallableDecl) { + crate::walk_utils::for_each_expr_in_callable_impl( + package, + &decl.implementation, + &mut |_expr_id, expr| { + assert!( + !matches!(expr.kind, ExprKind::Return(_)), + "PostReturnUnify invariant violation: ExprKind::Return found after return unification pass in callable '{}'", + decl.name.name + ); + }, + ); +} + +/// Recursively validates that a pattern tree contains no arrow-typed leaves. +/// +/// This helper is used by `check_no_arrow_params` so tuple-shaped callable +/// inputs are checked all the way down to their bound and discard leaves. +/// +/// # Panics +/// +/// Panics if any bound or discarded leaf still carries `Ty::Arrow`. +fn check_pat_for_arrow(package: &Package, pat_id: PatId) { + let pat = package.get_pat(pat_id); + match &pat.kind { + PatKind::Tuple(pats) => { + for &sub_pat_id in pats { + check_pat_for_arrow(package, sub_pat_id); + } + } + PatKind::Bind(_) => { + assert!( + !matches!(pat.ty, Ty::Arrow(_)), + "PostDefunc invariant violation: Arrow-typed parameter remains in callable input (Pat {pat_id})" + ); + } + PatKind::Discard => { + assert!( + !matches!(pat.ty, Ty::Arrow(_)), + "PostDefunc invariant violation: Arrow-typed discard parameter in callable input (Pat {pat_id})" + ); + } + } +} + +/// Validates the tuple-pattern shape of a callable's primary input pattern and +/// any specialization-specific input patterns. +/// +/// This check becomes relevant after tuple-decomposing stages such as SROA and +/// argument promotion, which may synthesize tuple-shaped inputs that must still +/// mirror the callable input types exactly. +/// +/// # Panics +/// +/// Panics if any callable or specialization input pattern has tuple structure +/// that does not match its declared type. 
+fn check_callable_input_pattern_shapes(package: &Package, decl: &CallableDecl) { + let callable_name = decl.name.name.to_string(); + check_tuple_pat_shape_matches_type( + package, + decl.input, + &format!("callable '{callable_name}' input"), + ); + + match &decl.implementation { + CallableImpl::Spec(spec_impl) => { + for (label, input_pat) in [ + ("body", spec_impl.body.input), + ("adj", spec_impl.adj.as_ref().and_then(|spec| spec.input)), + ("ctl", spec_impl.ctl.as_ref().and_then(|spec| spec.input)), + ( + "ctl_adj", + spec_impl.ctl_adj.as_ref().and_then(|spec| spec.input), + ), + ] { + if let Some(pat_id) = input_pat { + check_tuple_pat_shape_matches_type( + package, + pat_id, + &format!("callable '{callable_name}' {label} input"), + ); + } + } + } + CallableImpl::SimulatableIntrinsic(spec) => { + if let Some(pat_id) = spec.input { + check_tuple_pat_shape_matches_type( + package, + pat_id, + &format!("callable '{callable_name}' simulatable intrinsic input"), + ); + } + } + CallableImpl::Intrinsic => {} + } +} + +/// Validates the tuple-pattern shape of `pat_id` against its declared type. +/// +/// Recurses into `PatKind::Tuple` and requires the pattern arity to match the +/// `Ty::Tuple` element count exactly; each sub-pattern's type must equal the +/// corresponding tuple element type. `PatKind::Bind` and `PatKind::Discard` +/// are accepted unconditionally. `context` appears in panic messages to +/// disambiguate the calling site. 
+fn check_tuple_pat_shape_matches_type(package: &Package, pat_id: PatId, context: &str) { + let pat = package.get_pat(pat_id); + match &pat.kind { + PatKind::Tuple(pats) => { + let Ty::Tuple(elem_tys) = &pat.ty else { + panic!( + "Tuple pattern/type invariant violation: {context} Pat {pat_id} is tuple-shaped but has non-tuple type {:?}", + pat.ty, + ); + }; + + assert!( + pats.len() == elem_tys.len(), + "Tuple pattern/type invariant violation: {context} Pat {pat_id} has {} tuple elements but type has {} elements", + pats.len(), + elem_tys.len(), + ); + + for (index, (&sub_pat_id, elem_ty)) in pats.iter().zip(elem_tys.iter()).enumerate() { + let sub_pat_ty = &package.get_pat(sub_pat_id).ty; + assert!( + sub_pat_ty == elem_ty, + "Tuple pattern/type invariant violation: {context} Pat {pat_id} element {index} Pat {sub_pat_id} has type {sub_pat_ty:?} but tuple type expects {elem_ty:?}", + ); + check_tuple_pat_shape_matches_type(package, sub_pat_id, context); + } + } + PatKind::Bind(_) | PatKind::Discard => {} + } +} + +/// Asserts that no tuple-bound local leaf retains an arrow-typed field. +/// +/// Recurses into `PatKind::Tuple` to reach every `Bind`/`Discard` leaf, then +/// delegates to `tuple_type_contains_arrow` on the leaf's declared type. +fn check_local_pat_for_nested_tuple_arrow(package: &Package, pat_id: PatId) { + let pat = package.get_pat(pat_id); + match &pat.kind { + PatKind::Tuple(pats) => { + for &sub_pat_id in pats { + check_local_pat_for_nested_tuple_arrow(package, sub_pat_id); + } + } + PatKind::Bind(_) | PatKind::Discard => { + assert!( + !tuple_type_contains_arrow(&pat.ty), + "PostDefunc invariant violation: tuple-bound local retains an arrow-typed field (Pat {pat_id})" + ); + } + } +} + +/// Returns `true` when a `Ty::Tuple` contains any arrow-typed field, +/// transitively through nested tuples. Non-tuple types yield `false`. 
+fn tuple_type_contains_arrow(ty: &Ty) -> bool { + match ty { + Ty::Tuple(items) => items.iter().any(tuple_field_type_contains_arrow), + _ => false, + } +} + +/// Returns `true` when a tuple field type is itself an arrow or a tuple that +/// transitively contains one. Used by `tuple_type_contains_arrow` to walk +/// into nested tuple fields. +fn tuple_field_type_contains_arrow(ty: &Ty) -> bool { + match ty { + Ty::Arrow(_) => true, + Ty::Tuple(items) => items.iter().any(tuple_field_type_contains_arrow), + _ => false, + } +} + +/// Drives the statement walk for a single specialization body by forwarding +/// each statement to `check_stmt_types`. +fn check_spec_decl_types( + store: &PackageStore, + package: &Package, + spec: &qsc_fir::fir::SpecDecl, + level: InvariantLevel, +) { + let block = package.get_block(spec.block); + for &stmt_id in &block.stmts { + check_stmt_types(store, package, stmt_id, level); + } +} + +/// Applies the statement-local checks for a specialization block. +/// +/// For each local binding, this layers: +/// - `check_pat_types` on the bound pattern type. +/// - `check_tuple_pat_shape_matches_type` after tuple-decomposing stages. +/// - `check_local_pat_for_nested_tuple_arrow` after SROA (arrow types may +/// appear inside tuples between UDT erasure and SROA). +/// - `check_expr_types` on the initializer expression. +/// - a final initializer-type equality assertion at `PostAll`. +/// +/// Standalone expression statements are delegated directly to +/// `check_expr_types`. 
+fn check_stmt_types( + store: &PackageStore, + package: &Package, + stmt_id: qsc_fir::fir::StmtId, + level: InvariantLevel, +) { + let stmt = package.get_stmt(stmt_id); + match &stmt.kind { + StmtKind::Expr(e) | StmtKind::Semi(e) => check_expr_types(store, package, *e, level), + StmtKind::Local(_, pat, expr) => { + check_pat_types(package, *pat, level); + if level.is_post_sroa_or_later() { + check_tuple_pat_shape_matches_type(package, *pat, "local binding"); + check_local_pat_for_nested_tuple_arrow(package, *pat); + } + check_expr_types(store, package, *expr, level); + + if level == InvariantLevel::PostReturnUnify || level == InvariantLevel::PostAll { + let pat_ty = &package.get_pat(*pat).ty; + let init_ty = &package.get_expr(*expr).ty; + // Ty::Infer and Ty::Err should never appear at PostAll — all + // passes must have resolved these types by then. At + // PostReturnUnify, later passes may still need to resolve + // them, so skip the type-equality check for those types. + let has_unresolved = matches!(pat_ty, Ty::Err | Ty::Infer(_)) + || matches!(init_ty, Ty::Err | Ty::Infer(_)); + if !has_unresolved || level == InvariantLevel::PostAll { + assert!( + pat_ty == init_ty, + "PostReturnUnify invariant violation: local binding Pat {pat} has type \ + {pat_ty:?} but initializer Expr {expr} has type {init_ty:?}", + ); + } + } + } + StmtKind::Item(_) => {} + } +} + +/// Walks the full subtree rooted at `expr_id` and forwards every visited node +/// to `check_expr_type`. +fn check_expr_types( + store: &PackageStore, + package: &Package, + expr_id: ExprId, + level: InvariantLevel, +) { + crate::walk_utils::for_each_expr(package, expr_id, &mut |expr_id, _expr| { + check_expr_type(store, package, expr_id, level); + }); +} + +/// Applies node-local expression invariants. +/// +/// This always starts with `check_type_invariants` on the expression's own +/// type and then layers stage-specific structural checks on the expression +/// kind itself. 
+/// +/// The `PostUdtErase`-era expression-kind assertions here (for +/// [`ExprKind::Struct`], [`Field::Path`] in `UpdateField`/`AssignField`, and +/// [`Field::Path`] on non-tuple records) intentionally overlap with +/// `check_package_udt_erase_invariants`: this walker fires on every +/// reachable expression in the target package, while the package-wide walker +/// visits every expression in every reachable package. Both paths must agree +/// so a regression caught in either scope produces the same diagnostic. +fn check_expr_type( + store: &PackageStore, + package: &Package, + expr_id: ExprId, + level: InvariantLevel, +) { + let expr = package.get_expr(expr_id); + check_type_invariants(&expr.ty, level, &format!("Expr {expr_id}")); + + // After defunctionalization, no closures should remain in reachable code. + if level.is_post_defunc_or_later() { + assert!( + !matches!(&expr.kind, ExprKind::Closure(_, _)), + "Expr {expr_id} is a Closure after defunctionalization" + ); + } + + // PostMono: no remaining generic args on Var references. + if level.is_post_mono_or_later() + && let ExprKind::Var(_, args) = &expr.kind + { + assert!( + args.is_empty(), + "PostMono invariant violation: Expr {expr_id} still has non-empty generic args" + ); + } + + // After UDT erasure, all Struct expressions must have been lowered. + if level.is_post_udt_erase_or_later() { + if matches!(&expr.kind, ExprKind::Struct(_, _, _)) { + panic!( + "PostUdtErase invariant violation: Expr {expr_id} contains \ + ExprKind::Struct after UDT erasure" + ); + } + + // Field::Path references UDT field paths that must be lowered by udt_erase. + if let ExprKind::UpdateField(_, Field::Path(_), _) + | ExprKind::AssignField(_, Field::Path(_), _) = &expr.kind + { + panic!( + "PostUdtErase invariant violation: Expr {expr_id} contains \ + Field::Path in UpdateField/AssignField after UDT erasure" + ); + } + + // After UDT erasure, every Field::Path target must be a Tuple. 
+ if let ExprKind::Field(record_id, Field::Path(_)) = &expr.kind { + let record = package.get_expr(*record_id); + assert!( + matches!(&record.ty, Ty::Tuple(_)), + "PostUdtErase invariant violation: Expr {expr_id} has Field::Path \ + on non-tuple record Expr {record_id} (type: {:?})", + record.ty, + ); + } + } + + // After tuple comparison lowering, no BinOp(Eq/Neq) on non-empty tuple operands. + if level.is_post_tuple_comp_lower_or_later() + && let ExprKind::BinOp(BinOp::Eq | BinOp::Neq, lhs_id, _) = &expr.kind + { + let lhs_ty = &package.get_expr(*lhs_id).ty; + if let Ty::Tuple(elems) = lhs_ty { + assert!( + elems.is_empty(), + "PostTupleCompLower invariant violation: Expr {expr_id} has \ + BinOp(Eq/Neq) on tuple-typed operands" + ); + } + } + + // After defunctionalization, tuple expressions must have types with matching arity. + if level.is_post_defunc_or_later() + && let ExprKind::Tuple(es) = &expr.kind + && let Ty::Tuple(tys) = &expr.ty + { + assert!( + es.len() == tys.len(), + "Tuple arity mismatch: Expr {expr_id} has {} elements but type has {} elements", + es.len(), + tys.len() + ); + } + + if level.is_post_arg_promote_or_later() + && let ExprKind::Call(callee_id, arg_id) = &expr.kind + { + check_call_shape_matches_callee(store, package, expr_id, *callee_id, *arg_id); + } +} + +/// Verifies that a `ExprKind::Call` expression's argument type matches the +/// callee's declared input type and that the call's result type matches the +/// callee's declared output type. +/// +/// This is the post-`arg_promote` check that catches signature drift +/// introduced by tuple-decomposing stages. 
+fn check_call_shape_matches_callee( + store: &PackageStore, + package: &Package, + call_expr_id: ExprId, + callee_id: ExprId, + arg_id: ExprId, +) { + let arg = package.get_expr(arg_id); + + let Some((expected_input, expected_output)) = resolve_call_signature(store, package, callee_id) + else { + let callee = package.get_expr(callee_id); + panic!( + "PostArgPromote/PostAll call invariant violation: Expr {call_expr_id} calls Expr \ + {callee_id} whose signature cannot be resolved from callee type {:?}", + callee.ty, + ); + }; + + assert!( + arg.ty == expected_input, + "PostArgPromote/PostAll call invariant violation: Expr {call_expr_id} passes Expr \ + {arg_id} with type {:?} to callee Expr {callee_id} expecting input type \ + {expected_input:?}", + arg.ty, + ); + + let call = package.get_expr(call_expr_id); + assert!( + call.ty == expected_output, + "PostArgPromote/PostAll call invariant violation: Expr {call_expr_id} has type {:?} \ + but callee Expr {callee_id} returns {expected_output:?}", + call.ty, + ); +} + +/// Resolves a callee expression to its `(input_ty, output_ty)` signature. +/// +/// Handles the two callee forms that can appear after the pipeline runs: a +/// direct `Ty::Arrow`-typed expression (e.g., a captured callable value), and +/// an `ExprKind::Var(Res::Item, _)` pointing at a `Callable` item in any +/// package. Returns `None` when the callee is neither form; callers treat +/// `None` as an invariant violation. 
+fn resolve_call_signature( + store: &PackageStore, + package: &Package, + callee_id: ExprId, +) -> Option<(Ty, Ty)> { + let callee = package.get_expr(callee_id); + if let Ty::Arrow(arrow) = &callee.ty { + return Some(((*arrow.input).clone(), (*arrow.output).clone())); + } + + if let ExprKind::Var(Res::Item(item_id), _) = &callee.kind { + let callee_package = store.get(item_id.package); + let item = callee_package.get_item(item_id.item); + if let ItemKind::Callable(decl) = &item.kind { + let input_ty = callee_package.get_pat(decl.input).ty.clone(); + return Some((input_ty, decl.output.clone())); + } + } + + None +} + +/// Validates a pattern's declared type by delegating to +/// `check_type_invariants`. +fn check_pat_types(package: &Package, pat_id: PatId, level: InvariantLevel) { + let pat = package.get_pat(pat_id); + check_type_invariants(&pat.ty, level, &format!("Pat {pat_id}")); +} + +/// Recursively validates the stage-sensitive invariants for a type. +/// +/// This is the common type checker used by callable signatures, patterns, and +/// expressions. It enforces the type-form restrictions guaranteed by each +/// pipeline stage while walking into nested array, tuple, and arrow types. +/// +/// # Panics +/// +/// Panics when a type still contains a form that should have been eliminated by +/// the current invariant level, such as `Ty::Param`, `FunctorSet::Param`, or +/// `Ty::Udt`. 
+fn check_type_invariants(ty: &Ty, level: InvariantLevel, context: &str) { + match ty { + Ty::Param(_) => { + assert!( + !level.is_post_mono_or_later(), + "{context} contains Ty::Param after monomorphization" + ); + } + Ty::Arrow(arrow) => { + if level.is_post_mono_or_later() { + assert!( + !matches!(arrow.functors, FunctorSet::Param(_)), + "{context} contains FunctorSet::Param after monomorphization" + ); + } + if level.is_post_defunc_or_later() { + // `Ty::Arrow` leaves are allowed on callable outputs and + // cross-package items; the `PostDefunc` invariant targets + // arrow-typed callable *parameters*, enforced by + // `check_no_arrow_params`. + } + check_type_invariants(&arrow.input, level, context); + check_type_invariants(&arrow.output, level, context); + } + Ty::Array(inner) => check_type_invariants(inner, level, context), + Ty::Tuple(items) => { + for item in items { + check_type_invariants(item, level, context); + } + } + Ty::Udt(_) => { + assert!( + !level.is_post_udt_erase_or_later(), + "{context} contains Ty::Udt after UDT erasure" + ); + } + Ty::Infer(_) | Ty::Err => { + assert!( + level != InvariantLevel::PostAll, + "{context} contains unexpected Ty::Infer/Ty::Err — indicates a pass bug" + ); + } + Ty::Prim(_) => {} + } +} + +/// Verifies that every `Res::Local(id)` in a callable body refers to a +/// `LocalVarId` that is bound by: +/// - the callable's input pattern, +/// - a specialization input pattern, or +/// - a `PatKind::Bind` in a body-internal `StmtKind::Local`. +/// +/// # Panics +/// +/// Panics if a local reference is found that is not in the bound set. +fn check_local_var_consistency(package: &Package, decl: &CallableDecl) { + let mut bound: FxHashSet = FxHashSet::default(); + let mut refs: Vec<(ExprId, LocalVarId)> = Vec::new(); + + // Collect bindings from the callable's input pattern. 
+ collect_pat_bindings(package, decl.input, &mut bound); + + match &decl.implementation { + CallableImpl::Spec(spec_impl) => { + for spec in std::iter::once(&spec_impl.body) + .chain(spec_impl.adj.iter()) + .chain(spec_impl.ctl.iter()) + .chain(spec_impl.ctl_adj.iter()) + { + if let Some(input_pat) = spec.input { + collect_pat_bindings(package, input_pat, &mut bound); + } + walk_block_for_locals(package, spec.block, &mut bound, &mut refs); + } + } + CallableImpl::SimulatableIntrinsic(spec) => { + if let Some(input_pat) = spec.input { + collect_pat_bindings(package, input_pat, &mut bound); + } + walk_block_for_locals(package, spec.block, &mut bound, &mut refs); + } + CallableImpl::Intrinsic => {} + } + + // Assert every referenced local is bound. + for (expr_id, var_id) in &refs { + assert!( + bound.contains(var_id), + "LocalVarId consistency: Expr {expr_id} references {var_id}, \ + which is not bound in callable \"{}\"", + decl.name.name, + ); + } +} + +/// Recursively collects all `LocalVarId`s from `PatKind::Bind` nodes. +fn collect_pat_bindings(package: &Package, pat_id: PatId, bound: &mut FxHashSet) { + let pat = package.get_pat(pat_id); + match &pat.kind { + PatKind::Bind(ident) => { + bound.insert(ident.id); + } + PatKind::Discard => {} + PatKind::Tuple(pats) => { + for &sub in pats { + collect_pat_bindings(package, sub, bound); + } + } + } +} + +/// Walks a block, collecting both bindings and local references. 
+fn walk_block_for_locals( + package: &Package, + block_id: BlockId, + bound: &mut FxHashSet, + refs: &mut Vec<(ExprId, LocalVarId)>, +) { + let block = package.get_block(block_id); + for &stmt_id in &block.stmts { + let stmt = package.get_stmt(stmt_id); + match &stmt.kind { + StmtKind::Expr(e) | StmtKind::Semi(e) => { + walk_expr_for_locals(package, *e, bound, refs); + } + StmtKind::Local(_, pat, expr) => { + collect_pat_bindings(package, *pat, bound); + walk_expr_for_locals(package, *expr, bound, refs); + } + StmtKind::Item(_) => {} + } + } +} + +/// Walks an expression tree, recording `Res::Local` references and recursing +/// into sub-expressions and nested blocks. +fn walk_expr_for_locals( + package: &Package, + expr_id: ExprId, + bound: &mut FxHashSet, + refs: &mut Vec<(ExprId, LocalVarId)>, +) { + let expr = package.get_expr(expr_id); + + // Record local references. + match &expr.kind { + ExprKind::Var(Res::Local(id), _) => refs.push((expr_id, *id)), + ExprKind::Closure(ids, _) => { + for id in ids { + refs.push((expr_id, *id)); + } + } + _ => {} + } + + // Recurse into sub-expressions and sub-blocks. 
+ match &expr.kind { + ExprKind::Array(es) | ExprKind::ArrayLit(es) | ExprKind::Tuple(es) => { + for &e in es { + walk_expr_for_locals(package, e, bound, refs); + } + } + ExprKind::ArrayRepeat(a, b) + | ExprKind::Assign(a, b) + | ExprKind::AssignOp(_, a, b) + | ExprKind::BinOp(_, a, b) + | ExprKind::Call(a, b) + | ExprKind::Index(a, b) + | ExprKind::AssignField(a, _, b) + | ExprKind::UpdateField(a, _, b) => { + walk_expr_for_locals(package, *a, bound, refs); + walk_expr_for_locals(package, *b, bound, refs); + } + ExprKind::AssignIndex(a, b, c) | ExprKind::UpdateIndex(a, b, c) => { + walk_expr_for_locals(package, *a, bound, refs); + walk_expr_for_locals(package, *b, bound, refs); + walk_expr_for_locals(package, *c, bound, refs); + } + ExprKind::Block(block_id) => walk_block_for_locals(package, *block_id, bound, refs), + ExprKind::Fail(e) | ExprKind::Field(e, _) | ExprKind::Return(e) | ExprKind::UnOp(_, e) => { + walk_expr_for_locals(package, *e, bound, refs); + } + ExprKind::If(cond, body, otherwise) => { + walk_expr_for_locals(package, *cond, bound, refs); + walk_expr_for_locals(package, *body, bound, refs); + if let Some(e) = otherwise { + walk_expr_for_locals(package, *e, bound, refs); + } + } + ExprKind::Range(s, st, e) => { + if let Some(x) = s { + walk_expr_for_locals(package, *x, bound, refs); + } + if let Some(x) = st { + walk_expr_for_locals(package, *x, bound, refs); + } + if let Some(x) = e { + walk_expr_for_locals(package, *x, bound, refs); + } + } + ExprKind::Struct(_, copy, fields) => { + if let Some(c) = copy { + walk_expr_for_locals(package, *c, bound, refs); + } + for fa in fields { + walk_expr_for_locals(package, fa.value, bound, refs); + } + } + ExprKind::String(components) => { + for c in components { + if let qsc_fir::fir::StringComponent::Expr(e) = c { + walk_expr_for_locals(package, *e, bound, refs); + } + } + } + ExprKind::While(cond, block) => { + walk_expr_for_locals(package, *cond, bound, refs); + walk_block_for_locals(package, *block, 
bound, refs); + } + ExprKind::Closure(_, _) | ExprKind::Hole | ExprKind::Lit(_) | ExprKind::Var(_, _) => {} + } +} + +/// Validates structural integrity of a single configured exec graph. +/// +/// # Panics +/// +/// Panics with a descriptive message if any invariant is violated. +fn check_configured_exec_graph( + package: &Package, + nodes: &[ExecGraphNode], + context: &str, + config_label: &str, +) { + let len = nodes.len(); + assert!( + len > 0, + "Exec graph for {context} ({config_label}) is empty" + ); + + // Invariant E: graph terminates correctly. + match config_label { + "no_debug" => assert!( + matches!(nodes[len - 1], ExecGraphNode::Ret), + "Exec graph for {context} ({config_label}) does not end with Ret, found {:?}", + nodes[len - 1], + ), + "debug" => assert!( + matches!( + nodes[len - 1], + ExecGraphNode::Debug(ExecGraphDebugNode::RetFrame) + ), + "Exec graph for {context} ({config_label}) does not end with RetFrame, found {:?}", + nodes[len - 1], + ), + _ => {} + } + + for (i, node) in nodes.iter().enumerate() { + match node { + // Invariant A: jump targets within bounds. + ExecGraphNode::Jump(idx) + | ExecGraphNode::JumpIf(idx) + | ExecGraphNode::JumpIfNot(idx) => { + assert!( + (*idx as usize) < len, + "Exec graph for {context} ({config_label}): node {i} has jump target {idx} >= len {len}" + ); + } + // Invariant B: Expr references valid ExprId. + ExecGraphNode::Expr(expr_id) => { + assert!( + package.exprs.get(*expr_id).is_some(), + "Exec graph for {context} ({config_label}): node {i} references nonexistent Expr {expr_id}" + ); + } + // Invariant C: Bind references valid PatId. + ExecGraphNode::Bind(pat_id) => { + assert!( + package.pats.get(*pat_id).is_some(), + "Exec graph for {context} ({config_label}): node {i} references nonexistent Pat {pat_id}" + ); + } + // Invariant D: debug node ID references are valid. 
+ ExecGraphNode::Debug(debug_node) => match debug_node { + ExecGraphDebugNode::Stmt(stmt_id) => { + assert!( + package.stmts.get(*stmt_id).is_some(), + "Exec graph for {context} ({config_label}): node {i} references nonexistent Stmt {stmt_id}" + ); + } + ExecGraphDebugNode::PushLoopScope(expr_id) => { + assert!( + package.exprs.get(*expr_id).is_some(), + "Exec graph for {context} ({config_label}): node {i} PushLoopScope references nonexistent Expr {expr_id}" + ); + } + ExecGraphDebugNode::BlockEnd(block_id) => { + assert!( + package.blocks.get(*block_id).is_some(), + "Exec graph for {context} ({config_label}): node {i} BlockEnd references nonexistent Block {block_id}" + ); + } + ExecGraphDebugNode::PushScope + | ExecGraphDebugNode::PopScope + | ExecGraphDebugNode::RetFrame + | ExecGraphDebugNode::LoopIteration => {} + }, + ExecGraphNode::Store | ExecGraphNode::Unit | ExecGraphNode::Ret => {} + } + } +} + +/// Validates both configurations of a spec's exec graph. +/// +/// This fans out to `check_configured_exec_graph` for the compact and debug +/// views so both serialized forms are kept structurally consistent. +fn check_spec_exec_graph(package: &Package, spec: &SpecDecl, context: &str) { + for (config, label) in [ + (ExecGraphConfig::NoDebug, "no_debug"), + (ExecGraphConfig::Debug, "debug"), + ] { + let nodes = spec.exec_graph.select_ref(config); + check_configured_exec_graph(package, nodes, context, label); + } +} + +/// Verifies two ownership properties of `ExprId`s after defunctionalization: +/// +/// 1. **Per-spec uniqueness**: No `ExprId` appears in more than one +/// specialization body across all reachable callables. +/// 2. **Entry-vs-spec disjointness**: `ExprId`s reachable from the entry +/// expression are disjoint from those inside any specialization body. +/// +/// These properties ensure that RCA can assign per-arity `ComputeKind` +/// entries without collision. 
Defunctionalization's closure cleanup pass +/// is the primary mechanism that establishes property (2) for producer +/// function bodies that originally contained closure nodes. +/// +/// # Panics +/// +/// Panics with a descriptive message if any `ExprId` is shared. +fn check_expr_id_ownership( + store: &PackageStore, + package_id: PackageId, + reachable: &FxHashSet, + entry_id: ExprId, +) { + let package = store.get(package_id); + + // Map each ExprId to the (item, spec_label) that owns it. + let mut seen: FxHashMap = FxHashMap::default(); + + for item_id in reachable { + if item_id.package != package_id { + continue; + } + let item = package.get_item(item_id.item); + let ItemKind::Callable(decl) = &item.kind else { + continue; + }; + + let specs: Vec<(&SpecDecl, &'static str)> = match &decl.implementation { + CallableImpl::Spec(spec_impl) => { + let mut v = vec![(&spec_impl.body, "body")]; + if let Some(adj) = &spec_impl.adj { + v.push((adj, "adj")); + } + if let Some(ctl) = &spec_impl.ctl { + v.push((ctl, "ctl")); + } + if let Some(cta) = &spec_impl.ctl_adj { + v.push((cta, "ctl_adj")); + } + v + } + CallableImpl::SimulatableIntrinsic(spec) => { + vec![(spec, "sim")] + } + CallableImpl::Intrinsic => continue, + }; + + for (spec, label) in specs { + let mut expr_ids = FxHashSet::default(); + collect_expr_ids_in_block(package, spec.block, &mut expr_ids); + for eid in &expr_ids { + if let Some((prev_item, prev_label)) = seen.get(eid) { + panic!( + "PostDefunc ExprId uniqueness violation: {eid} appears in \ + both {prev_item}/{prev_label} and {}/{label}", + item_id.item, + ); + } + seen.insert(*eid, (item_id.item, label)); + } + } + } + + // Check entry expression ExprIds are disjoint from spec body ExprIds. 
+ let mut entry_expr_ids = FxHashSet::default(); + collect_expr_ids_in_expr(package, entry_id, &mut entry_expr_ids); + for eid in &entry_expr_ids { + if let Some((owner_item, owner_label)) = seen.get(eid) { + panic!( + "PostDefunc entry/spec disjointness violation: {eid} appears in \ + both the entry expression and {owner_item}/{owner_label}", + ); + } + } +} + +/// Recursively collects all `ExprId`s reachable from a block. +fn collect_expr_ids_in_block(package: &Package, block_id: BlockId, ids: &mut FxHashSet) { + let block = package.get_block(block_id); + for &stmt_id in &block.stmts { + let stmt = package.get_stmt(stmt_id); + match &stmt.kind { + StmtKind::Expr(e) | StmtKind::Semi(e) | StmtKind::Local(_, _, e) => { + collect_expr_ids_in_expr(package, *e, ids); + } + StmtKind::Item(_) => {} + } + } +} + +/// Recursively collects all `ExprId`s reachable from an expression. +fn collect_expr_ids_in_expr(package: &Package, expr_id: ExprId, ids: &mut FxHashSet) { + ids.insert(expr_id); + crate::walk_utils::for_each_expr(package, expr_id, &mut |child_id, _| { + ids.insert(child_id); + }); +} diff --git a/source/compiler/qsc_fir_transforms/src/invariants/tests.rs b/source/compiler/qsc_fir_transforms/src/invariants/tests.rs new file mode 100644 index 0000000000..bd752b329c --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/invariants/tests.rs @@ -0,0 +1,1012 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use super::*; +use crate::test_utils::{PipelineStage, compile_and_run_pipeline_to}; +use crate::walk_utils; +use qsc_fir::fir::{ + CallableImpl, CallableKind, ExprId, ExprKind, Field, FieldPath, ItemKind, LocalItemId, + LocalVarId, PackageLookup, PatId, PatKind, Res, StmtKind, +}; +use qsc_fir::ty::{Arrow, FunctorSet, FunctorSetValue, ParamId, Prim}; + +/// Finds the first expression directly referenced by a statement in a +/// callable body within the package. 
The invariant checker visits these +/// expressions via `check_stmt_types`, so mutations here will be detected. +fn find_body_stmt_expr(pkg: &Package) -> ExprId { + for item in pkg.items.values() { + if let ItemKind::Callable(decl) = &item.kind + && let CallableImpl::Spec(spec_impl) = &decl.implementation + { + let block = pkg.get_block(spec_impl.body.block); + for &stmt_id in &block.stmts { + let stmt = pkg.get_stmt(stmt_id); + match &stmt.kind { + StmtKind::Expr(e) | StmtKind::Semi(e) | StmtKind::Local(_, _, e) => return *e, + StmtKind::Item(_) => {} + } + } + } + } + panic!("no statement-level expression found in package"); +} + +fn find_nested_expr_in_callable(pkg: &Package, mut predicate: F) -> ExprId +where + F: FnMut(&Package, ExprId, &qsc_fir::fir::Expr) -> bool, +{ + let stmt_roots = collect_stmt_expr_roots(pkg); + + for item in pkg.items.values() { + if let ItemKind::Callable(decl) = &item.kind { + let mut found = None; + walk_utils::for_each_expr_in_callable_impl( + pkg, + &decl.implementation, + &mut |expr_id, expr| { + if found.is_none() + && !stmt_roots.contains(&expr_id) + && predicate(pkg, expr_id, expr) + { + found = Some(expr_id); + } + }, + ); + + if let Some(expr_id) = found { + return expr_id; + } + } + } + + panic!("no nested expression found in package"); +} + +fn mutate_nested_expr_in_callable( + store: &mut PackageStore, + pkg_id: qsc_fir::fir::PackageId, + predicate: F, + mutate: M, +) where + F: FnMut(&Package, ExprId, &qsc_fir::fir::Expr) -> bool, + M: FnOnce(&mut Package, ExprId), +{ + let target_id = { + let pkg = store.get(pkg_id); + find_nested_expr_in_callable(pkg, predicate) + }; + + let pkg = store.get_mut(pkg_id); + mutate(pkg, target_id); +} + +fn find_expr_in_named_callable(pkg: &Package, callable_name: &str, mut predicate: F) -> ExprId +where + F: FnMut(&Package, ExprId, &qsc_fir::fir::Expr) -> bool, +{ + for item in pkg.items.values() { + if let ItemKind::Callable(decl) = &item.kind + && decl.name.name.as_ref() == 
callable_name + { + let mut found = None; + walk_utils::for_each_expr_in_callable_impl( + pkg, + &decl.implementation, + &mut |expr_id, expr| { + if found.is_none() && predicate(pkg, expr_id, expr) { + found = Some(expr_id); + } + }, + ); + + if let Some(expr_id) = found { + return expr_id; + } + } + } + + panic!("no matching expression found in callable '{callable_name}'"); +} + +fn find_local_tuple_pat(pkg: &Package) -> PatId { + for item in pkg.items.values() { + if let ItemKind::Callable(decl) = &item.kind + && let CallableImpl::Spec(spec_impl) = &decl.implementation + { + let block = pkg.get_block(spec_impl.body.block); + for &stmt_id in &block.stmts { + let stmt = pkg.get_stmt(stmt_id); + if let StmtKind::Local(_, pat_id, _) = stmt.kind + && matches!(pkg.get_pat(pat_id).kind, PatKind::Tuple(_)) + { + return pat_id; + } + } + } + } + + panic!("no tuple local pattern found in package"); +} + +fn find_callable_input_tuple_pat(pkg: &Package, callable_name: &str) -> PatId { + for item in pkg.items.values() { + if let ItemKind::Callable(decl) = &item.kind + && decl.name.name.as_ref() == callable_name + && matches!(pkg.get_pat(decl.input).kind, PatKind::Tuple(_)) + { + return decl.input; + } + } + + panic!("no tuple input pattern found for callable '{callable_name}'"); +} + +fn truncate_tuple_pat(pkg: &mut Package, pat_id: PatId) { + let PatKind::Tuple(sub_pats) = &pkg.get_pat(pat_id).kind else { + panic!("expected tuple pattern") + }; + assert!( + sub_pats.len() >= 2, + "tuple pattern must have at least two elements" + ); + + let mut truncated = sub_pats.clone(); + truncated.pop(); + + let pat = pkg.pats.get_mut(pat_id).expect("pat not found"); + pat.kind = PatKind::Tuple(truncated); +} + +fn collect_stmt_expr_roots(pkg: &Package) -> Vec { + let mut roots = Vec::new(); + + for item in pkg.items.values() { + if let ItemKind::Callable(decl) = &item.kind + && let CallableImpl::Spec(spec_impl) = &decl.implementation + { + collect_stmt_expr_roots_in_block(pkg, 
spec_impl.body.block, &mut roots);
            for spec in crate::fir_builder::functored_specs(spec_impl) {
                collect_stmt_expr_roots_in_block(pkg, spec.block, &mut roots);
            }
        }
    }

    roots
}

/// Pushes the root expression of every statement in `block_id` onto `roots`.
/// `StmtKind::Item` declares a nested item and carries no expression, so it is
/// skipped.
fn collect_stmt_expr_roots_in_block(
    pkg: &Package,
    block_id: qsc_fir::fir::BlockId,
    roots: &mut Vec<ExprId>,
) {
    let block = pkg.get_block(block_id);
    for &stmt_id in &block.stmts {
        let stmt = pkg.get_stmt(stmt_id);
        match stmt.kind {
            StmtKind::Expr(expr_id) | StmtKind::Semi(expr_id) | StmtKind::Local(_, _, expr_id) => {
                roots.push(expr_id);
            }
            StmtKind::Item(_) => {}
        }
    }
}

/// Replaces the first `Res::Local` reference in the package with one pointing
/// to `bad_id`, which should not be bound anywhere. The local-var consistency
/// check walks the entire callable body recursively, so any `Res::Local` is
/// reachable.
fn inject_stale_local_var(
    store: &mut PackageStore,
    pkg_id: qsc_fir::fir::PackageId,
    bad_id: LocalVarId,
) {
    let pkg = store.get_mut(pkg_id);
    for expr in pkg.exprs.values_mut() {
        if let ExprKind::Var(Res::Local(_), _) = &expr.kind {
            expr.kind = ExprKind::Var(Res::Local(bad_id), vec![]);
            return;
        }
    }
    panic!("no Res::Local expression found to mutate");
}

/// Like `inject_stale_local_var`, but scoped to the named callable's body.
fn inject_stale_local_var_in_callable(
    store: &mut PackageStore,
    pkg_id: qsc_fir::fir::PackageId,
    callable_name: &str,
    bad_id: LocalVarId,
) {
    let target_id = {
        let pkg = store.get(pkg_id);
        find_expr_in_named_callable(pkg, callable_name, |_, _, expr| {
            matches!(expr.kind, ExprKind::Var(Res::Local(_), _))
        })
    };

    let pkg = store.get_mut(pkg_id);
    let expr = pkg.exprs.get_mut(target_id).expect("expr not found");
    expr.kind = ExprKind::Var(Res::Local(bad_id), vec![]);
}

/// Changes the type of some expression inside the named callable to `Ty::Udt`
/// pointing at a fabricated item id.
fn inject_udt_expr_type_in_callable(
    store: &mut PackageStore,
    pkg_id: qsc_fir::fir::PackageId,
    callable_name: &str,
) {
    let target_id = {
        let pkg = store.get(pkg_id);
        find_expr_in_named_callable(pkg, callable_name, |_, _, _| true)
    };

    let pkg = store.get_mut(pkg_id);
    let fake_item_id = qsc_fir::fir::ItemId {
        package: pkg_id,
        item: LocalItemId::from(0usize),
    };
    let expr = pkg.exprs.get_mut(target_id).expect("expr not found");
    expr.ty = Ty::Udt(Res::Item(fake_item_id));
}

/// Truncates a local-binding tuple pattern so its arity no longer matches its
/// type.
fn inject_local_tuple_pattern_arity_mismatch(
    store: &mut PackageStore,
    pkg_id: qsc_fir::fir::PackageId,
) {
    let pat_id = {
        let pkg = store.get(pkg_id);
        find_local_tuple_pat(pkg)
    };

    let pkg = store.get_mut(pkg_id);
    truncate_tuple_pat(pkg, pat_id);
}

/// Truncates the named callable's input tuple pattern so its arity no longer
/// matches its type.
fn inject_callable_input_tuple_pattern_arity_mismatch(
    store: &mut PackageStore,
    pkg_id: qsc_fir::fir::PackageId,
    callable_name: &str,
) {
    let pat_id = {
        let pkg = store.get(pkg_id);
        find_callable_input_tuple_pat(pkg, callable_name)
    };

    let pkg = store.get_mut(pkg_id);
    truncate_tuple_pat(pkg, pat_id);
}

/// Rewrites a call whose callee expects a tuple so that only the first element
/// of the original tuple argument is passed, producing an argument-shape
/// mismatch.
fn inject_call_argument_shape_mismatch(
    store: &mut PackageStore,
    pkg_id: qsc_fir::fir::PackageId,
    callable_name: &str,
) {
    let (call_expr_id, callee_id, mismatched_arg_id) = {
        let pkg = store.get(pkg_id);
        let call_expr_id = find_expr_in_named_callable(pkg, callable_name, |pkg, _expr_id, expr| {
            let ExprKind::Call(callee_id, arg_id) = expr.kind else {
                return false;
            };

            matches!(call_input_ty(pkg, pkg_id, callee_id), Some(Ty::Tuple(_)))
                && matches!(&pkg.get_expr(arg_id).kind, ExprKind::Tuple(elems) if !elems.is_empty())
        });

        let ExprKind::Call(callee_id, arg_id) = pkg.get_expr(call_expr_id).kind else {
            panic!("expected call expression")
        };
        let ExprKind::Tuple(elems) = &pkg.get_expr(arg_id).kind else {
            panic!("expected tuple call argument")
        };

        (call_expr_id, callee_id, elems[0])
    };

    let pkg = store.get_mut(pkg_id);
    let call_expr = pkg
        .exprs
        .get_mut(call_expr_id)
        .expect("call expr not found");
    call_expr.kind = ExprKind::Call(callee_id, mismatched_arg_id);
}

/// Resolves the declared input type of `callee_id`: from its arrow type when
/// present, otherwise from the local callable declaration it names.
fn call_input_ty(pkg: &Package, pkg_id: qsc_fir::fir::PackageId, callee_id: ExprId) -> Option<Ty> {
    let callee = pkg.get_expr(callee_id);
    if let Ty::Arrow(arrow) = &callee.ty {
        return Some((*arrow.input).clone());
    }

    if let ExprKind::Var(Res::Item(item_id), _) = &callee.kind
        && item_id.package == pkg_id
    {
        let item = pkg.get_item(item_id.item);
        if let ItemKind::Callable(decl) = &item.kind {
            return Some(pkg.get_pat(decl.input).ty.clone());
        }
    }

    None
}

/// Changes the type of the entry expression to `Ty::Udt`.
fn inject_udt_expr_type(store: &mut PackageStore, pkg_id: qsc_fir::fir::PackageId) {
    let pkg = store.get_mut(pkg_id);
    let entry_id = pkg.entry.expect("package has no entry");
    let fake_item_id = qsc_fir::fir::ItemId {
        package: pkg_id,
        item: LocalItemId::from(0usize),
    };
    let expr = pkg.exprs.get_mut(entry_id).expect("entry expr not found");
    expr.ty = Ty::Udt(Res::Item(fake_item_id));
}

/// Changes the output type of the first reachable callable to `Ty::Udt`.
fn inject_udt_callable_output(store: &mut PackageStore, pkg_id: qsc_fir::fir::PackageId) {
    let pkg = store.get_mut(pkg_id);
    let fake_item_id = qsc_fir::fir::ItemId {
        package: pkg_id,
        item: LocalItemId::from(0usize),
    };
    for item in pkg.items.values_mut() {
        if let ItemKind::Callable(decl) = &mut item.kind {
            decl.output = Ty::Udt(Res::Item(fake_item_id));
            return;
        }
    }
    panic!("no callable found to mutate");
}

/// Changes the type of the entry expression to `Ty::Arrow` with
/// `FunctorSet::Param`.
fn inject_functor_param_arrow(store: &mut PackageStore, pkg_id: qsc_fir::fir::PackageId) {
    let pkg = store.get_mut(pkg_id);
    let entry_id = pkg.entry.expect("package has no entry");
    let expr = pkg.exprs.get_mut(entry_id).expect("entry expr not found");
    expr.ty = Ty::Arrow(Box::new(Arrow {
        kind: CallableKind::Operation,
        input: Box::new(Ty::Prim(Prim::Int)),
        output: Box::new(Ty::Prim(Prim::Int)),
        functors: FunctorSet::Param(ParamId::from(0usize)),
    }));
}

/// Changes the type of the entry expression to `Ty::Param`.
+fn inject_ty_param(store: &mut PackageStore, pkg_id: qsc_fir::fir::PackageId) { + let pkg = store.get_mut(pkg_id); + let entry_id = pkg.entry.expect("package has no entry"); + let expr = pkg.exprs.get_mut(entry_id).expect("entry expr not found"); + expr.ty = Ty::Param(ParamId::from(0usize)); +} + +/// Changes a statement-level body expression to `ExprKind::Closure`. +fn inject_closure_expr(store: &mut PackageStore, pkg_id: qsc_fir::fir::PackageId) { + let pkg = store.get_mut(pkg_id); + let target_id = find_body_stmt_expr(pkg); + let expr = pkg.exprs.get_mut(target_id).expect("expr not found"); + expr.ty = Ty::Arrow(Box::new(Arrow { + kind: CallableKind::Function, + input: Box::new(Ty::Prim(Prim::Int)), + output: Box::new(Ty::Prim(Prim::Int)), + functors: FunctorSet::Value(FunctorSetValue::Empty), + })); + expr.kind = ExprKind::Closure(vec![], LocalItemId::from(0usize)); +} + +/// Changes the type of the first callable's input pattern to `Ty::Arrow`. +fn inject_arrow_param(store: &mut PackageStore, pkg_id: qsc_fir::fir::PackageId) { + let pkg = store.get_mut(pkg_id); + let mut input_pat_id = None; + for item in pkg.items.values() { + if let ItemKind::Callable(decl) = &item.kind { + input_pat_id = Some(decl.input); + break; + } + } + let pat_id = input_pat_id.expect("no callable found"); + let pat = pkg.pats.get_mut(pat_id).expect("pat not found"); + pat.ty = Ty::Arrow(Box::new(Arrow { + kind: CallableKind::Operation, + input: Box::new(Ty::Prim(Prim::Int)), + output: Box::new(Ty::Prim(Prim::Int)), + functors: FunctorSet::Value(FunctorSetValue::Empty), + })); +} + +/// Changes the first local binding pattern to a nested tuple type containing an +/// arrow-typed field. 
+fn inject_nested_tuple_bound_arrow_local( + store: &mut PackageStore, + pkg_id: qsc_fir::fir::PackageId, +) { + let pkg = store.get_mut(pkg_id); + let mut local_pat_id = None; + + 'items: for item in pkg.items.values() { + if let ItemKind::Callable(decl) = &item.kind + && let CallableImpl::Spec(spec_impl) = &decl.implementation + { + let block = pkg.get_block(spec_impl.body.block); + for &stmt_id in &block.stmts { + let stmt = pkg.get_stmt(stmt_id); + if let StmtKind::Local(_, pat_id, _) = stmt.kind { + local_pat_id = Some(pat_id); + break 'items; + } + } + } + } + + let pat_id = local_pat_id.expect("no Local stmt found to mutate"); + let pat = pkg.pats.get_mut(pat_id).expect("pat not found"); + pat.ty = Ty::Tuple(vec![ + Ty::Tuple(vec![ + Ty::Arrow(Box::new(Arrow { + kind: CallableKind::Operation, + input: Box::new(Ty::Prim(Prim::Int)), + output: Box::new(Ty::Prim(Prim::Int)), + functors: FunctorSet::Value(FunctorSetValue::Empty), + })), + Ty::Prim(Prim::Int), + ]), + Ty::Prim(Prim::Int), + ]); +} + +/// Injects a non-copy `ExprKind::Struct` (copy slot = `None`) into a +/// statement-level body expression. +fn inject_non_copy_struct(store: &mut PackageStore, pkg_id: qsc_fir::fir::PackageId) { + let pkg = store.get_mut(pkg_id); + let target_id = find_body_stmt_expr(pkg); + let fake_item_id = qsc_fir::fir::ItemId { + package: pkg_id, + item: LocalItemId::from(0usize), + }; + let expr = pkg.exprs.get_mut(target_id).expect("expr not found"); + expr.kind = ExprKind::Struct(Res::Item(fake_item_id), None, vec![]); +} + +/// Simple Q# source with a local variable binding. 
+const SIMPLE_LOCAL_VAR: &str = r#" + namespace Test { + @EntryPoint() + function Main() : Int { + let x = 42; + x + } + } +"#; + +#[test] +fn invariant_passes_with_valid_local_var() { + let (store, pkg_id) = compile_and_run_pipeline_to(SIMPLE_LOCAL_VAR, PipelineStage::Mono); + check(&store, pkg_id, InvariantLevel::PostMono); +} + +#[test] +fn post_udt_erase_passes_when_no_udt_types() { + let (store, pkg_id) = compile_and_run_pipeline_to(SIMPLE_LOCAL_VAR, PipelineStage::UdtErase); + check(&store, pkg_id, InvariantLevel::PostUdtErase); +} + +#[test] +fn post_udt_erase_allows_copy_update_struct() { + let source = r#" + namespace Test { + struct Pair { Fst: Int, Snd: Int } + @EntryPoint() + function Main() : Int { + let p = new Pair { Fst = 1, Snd = 2 }; + let q = new Pair { ...p, Fst = 10 }; + q.Fst + } + } + "#; + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::UdtErase); + check(&store, pkg_id, InvariantLevel::PostUdtErase); +} + +#[test] +fn integration_post_udt_erase_invariant_passes() { + let source = r#" + namespace Test { + struct Pair { Fst: Int, Snd: Double } + @EntryPoint() + function Main() : (Int, Double) { + let p = new Pair { Fst = 1, Snd = 2.0 }; + (p.Fst, p.Snd) + } + } + "#; + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::UdtErase); + check(&store, pkg_id, InvariantLevel::PostUdtErase); +} + +#[test] +fn invariant_post_all_passes_after_full_pipeline() { + let source = r#" + namespace Test { + struct Pair { Fst: Int, Snd: Double } + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { op(q); } + @EntryPoint() + operation Main() : Unit { + let p = new Pair { Fst = 1, Snd = 2.0 }; + use q = Qubit(); + ApplyOp(H, q); + } + } + "#; + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Full); + check(&store, pkg_id, InvariantLevel::PostAll); +} + +#[test] +#[should_panic(expected = "LocalVarId consistency")] +fn invariant_catches_stale_local_var() { + let (mut store, pkg_id) = 
compile_and_run_pipeline_to(SIMPLE_LOCAL_VAR, PipelineStage::Mono); + inject_stale_local_var(&mut store, pkg_id, LocalVarId::from(9999u32)); + check(&store, pkg_id, InvariantLevel::PostMono); +} + +#[test] +#[should_panic(expected = "Ty::Udt after UDT erasure")] +fn post_udt_erase_catches_remaining_udt_type() { + let (mut store, pkg_id) = + compile_and_run_pipeline_to(SIMPLE_LOCAL_VAR, PipelineStage::UdtErase); + inject_udt_expr_type(&mut store, pkg_id); + check(&store, pkg_id, InvariantLevel::PostUdtErase); +} + +#[test] +#[should_panic(expected = "ExprKind::Struct after UDT erasure")] +fn post_udt_erase_catches_non_copy_struct_expr() { + let (mut store, pkg_id) = + compile_and_run_pipeline_to(SIMPLE_LOCAL_VAR, PipelineStage::UdtErase); + inject_non_copy_struct(&mut store, pkg_id); + check(&store, pkg_id, InvariantLevel::PostUdtErase); +} + +#[test] +#[should_panic(expected = "Ty::Udt after UDT erasure")] +fn post_udt_erase_catches_udt_in_callable_output() { + let (mut store, pkg_id) = + compile_and_run_pipeline_to(SIMPLE_LOCAL_VAR, PipelineStage::UdtErase); + inject_udt_callable_output(&mut store, pkg_id); + check(&store, pkg_id, InvariantLevel::PostUdtErase); +} + +#[test] +#[should_panic(expected = "FunctorSet::Param after monomorphization")] +fn invariant_catches_functor_set_param_post_mono() { + let (mut store, pkg_id) = compile_and_run_pipeline_to(SIMPLE_LOCAL_VAR, PipelineStage::Mono); + inject_functor_param_arrow(&mut store, pkg_id); + check(&store, pkg_id, InvariantLevel::PostMono); +} + +#[test] +#[should_panic(expected = "Closure")] +fn invariant_post_defunc_catches_closure() { + let (mut store, pkg_id) = compile_and_run_pipeline_to(SIMPLE_LOCAL_VAR, PipelineStage::Defunc); + inject_closure_expr(&mut store, pkg_id); + check(&store, pkg_id, InvariantLevel::PostDefunc); +} + +#[test] +#[should_panic(expected = "Arrow")] +fn invariant_post_defunc_catches_arrow_param() { + // Need a callable with a named parameter (PatKind::Bind) so the + // arrow-type 
injection is caught by check_pat_for_arrow. + let source = r#" + namespace Test { + function Helper(x : Int) : Int { x } + @EntryPoint() + function Main() : Int { Helper(42) } + } + "#; + let (mut store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Defunc); + inject_arrow_param(&mut store, pkg_id); + check(&store, pkg_id, InvariantLevel::PostDefunc); +} + +#[test] +#[should_panic(expected = "tuple-bound local retains an arrow-typed field")] +fn post_sroa_catches_nested_tuple_bound_arrow() { + let source = r#" + namespace Test { + @EntryPoint() + function Main() : ((Int, Int), Int) { + let value = ((1, 2), 3); + value + } + } + "#; + let (mut store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Sroa); + inject_nested_tuple_bound_arrow_local(&mut store, pkg_id); + check(&store, pkg_id, InvariantLevel::PostSroa); +} + +#[test] +#[should_panic(expected = "Ty::Param")] +fn invariant_post_mono_catches_ty_param() { + let (mut store, pkg_id) = compile_and_run_pipeline_to(SIMPLE_LOCAL_VAR, PipelineStage::Mono); + inject_ty_param(&mut store, pkg_id); + check(&store, pkg_id, InvariantLevel::PostMono); +} + +/// Finds a statement-level expression and rewrites it as a +/// `Field::Path` whose record expression has `Ty::Prim(Int)` instead +/// of `Ty::Tuple`, triggering the `PostUdtErase` invariant violation. +fn inject_non_tuple_field_path_target(store: &mut PackageStore, pkg_id: qsc_fir::fir::PackageId) { + let pkg = store.get_mut(pkg_id); + let target_id = find_body_stmt_expr(pkg); + // Use the target as both the record and the outer expression—just + // change the outer's kind to Field::Path pointing at itself-like expr. + // We need a second expr to act as the "record". Pick any other expr. + let mut record_id = None; + for (eid, _) in &pkg.exprs { + if eid != target_id { + record_id = Some(eid); + break; + } + } + let record_id = record_id.expect("need at least two expressions"); + // Set the record expr to a non-tuple type. 
+ let record = pkg.exprs.get_mut(record_id).expect("record expr not found"); + record.ty = Ty::Prim(Prim::Int); + // Rewrite the target as Field::Path referencing that record. + let target = pkg.exprs.get_mut(target_id).expect("expr not found"); + target.kind = ExprKind::Field(record_id, Field::Path(FieldPath::default())); + target.ty = Ty::Prim(Prim::Int); +} + +fn inject_nested_non_tuple_field_path_target( + store: &mut PackageStore, + pkg_id: qsc_fir::fir::PackageId, +) { + let (target_id, record_id) = { + let pkg = store.get(pkg_id); + let target_id = find_nested_expr_in_callable(pkg, |_, _, _| true); + let record_id = pkg + .exprs + .iter() + .find_map(|(expr_id, _)| (expr_id != target_id).then_some(expr_id)) + .expect("need at least two expressions"); + (target_id, record_id) + }; + + let pkg = store.get_mut(pkg_id); + let record = pkg.exprs.get_mut(record_id).expect("record expr not found"); + record.ty = Ty::Prim(Prim::Int); + + let target = pkg.exprs.get_mut(target_id).expect("expr not found"); + target.kind = ExprKind::Field(record_id, Field::Path(FieldPath::default())); + target.ty = Ty::Prim(Prim::Int); +} + +fn inject_nested_tuple_eq_in_if_branch(store: &mut PackageStore, pkg_id: qsc_fir::fir::PackageId) { + mutate_nested_expr_in_callable( + store, + pkg_id, + |pkg, _expr_id, expr| match &expr.kind { + ExprKind::Tuple(items) if items.len() == 2 => items + .iter() + .all(|item_id| matches!(pkg.get_expr(*item_id).ty, Ty::Tuple(_))), + _ => false, + }, + |pkg, target_id| { + let (lhs_id, rhs_id) = match &pkg.get_expr(target_id).kind { + ExprKind::Tuple(items) => (items[0], items[1]), + _ => panic!("nested target is not a tuple expression"), + }; + + let target = pkg.exprs.get_mut(target_id).expect("expr not found"); + target.kind = ExprKind::BinOp(BinOp::Eq, lhs_id, rhs_id); + target.ty = Ty::Prim(Prim::Bool); + }, + ); +} + +/// Finds a tuple expression in the package and changes its type to have a +/// different element count, triggering the tuple arity 
mismatch invariant. +fn inject_tuple_arity_mismatch(store: &mut PackageStore, pkg_id: qsc_fir::fir::PackageId) { + let pkg = store.get_mut(pkg_id); + for expr in pkg.exprs.values_mut() { + if let ExprKind::Tuple(es) = &expr.kind + && es.len() >= 2 + { + // Shrink the type tuple to have fewer elements than the expression. + expr.ty = Ty::Tuple(vec![Ty::Prim(Prim::Int); es.len() - 1]); + return; + } + } + panic!("no Tuple expression with >= 2 elements found to mutate"); +} + +fn convert_last_body_expr_to_semi(store: &mut PackageStore, pkg_id: qsc_fir::fir::PackageId) { + let pkg = store.get_mut(pkg_id); + for item in pkg.items.values() { + if let ItemKind::Callable(decl) = &item.kind + && let CallableImpl::Spec(spec_impl) = &decl.implementation + { + let block = pkg.blocks.get_mut(spec_impl.body.block).expect("block"); + let stmt_id = *block.stmts.last().expect("block should have stmts"); + let stmt = pkg.stmts.get_mut(stmt_id).expect("stmt not found"); + let StmtKind::Expr(expr_id) = stmt.kind else { + panic!("expected trailing Expr stmt") + }; + stmt.kind = StmtKind::Semi(expr_id); + return; + } + } + panic!("no callable body block found to mutate"); +} + +/// Finds a `StmtKind::Local` and changes the initializer expression's type +/// so it no longer matches the pattern type. 
+fn inject_binding_type_mismatch(store: &mut PackageStore, pkg_id: qsc_fir::fir::PackageId) { + let pkg = store.get_mut(pkg_id); + for item in pkg.items.values() { + if let ItemKind::Callable(decl) = &item.kind + && let CallableImpl::Spec(spec_impl) = &decl.implementation + { + let block = pkg.get_block(spec_impl.body.block); + for &stmt_id in &block.stmts { + let stmt = pkg.get_stmt(stmt_id); + if let StmtKind::Local(_, pat_id, expr_id) = &stmt.kind { + let pat_ty = &pkg.get_pat(*pat_id).ty; + if matches!(pat_ty, Ty::Prim(Prim::Int)) { + let init = pkg.exprs.get_mut(*expr_id).expect("init expr not found"); + init.ty = Ty::Prim(Prim::Double); + return; + } + } + } + } + } + panic!("no Local stmt with Prim(Int) pattern found to mutate"); +} + +/// Q# with a struct field access to ensure `Field::Path` survives the full pipeline. +const STRUCT_FIELD_ACCESS: &str = r#" + namespace Test { + struct Pair { Fst: Int, Snd: Double } + @EntryPoint() + function Main() : (Int, Double) { + let p = new Pair { Fst = 1, Snd = 2.0 }; + (p.Fst, p.Snd) + } + } +"#; + +const STRUCT_FIELD_ACCESS_INSIDE_IF: &str = r#" + namespace Test { + @EntryPoint() + function Main() : (Int, Double) { + if true { + (1, 2.0) + } else { + (0, 0.0) + } + } + } +"#; + +const PROMOTED_CALLABLE_INPUT: &str = r#" + namespace Test { + struct Pair { Fst: Int, Snd: Int } + + function Foo(p : Pair) : Int { + p.Fst + p.Snd + } + + @EntryPoint() + function Main() : Int { + Foo(new Pair { Fst = 1, Snd = 2 }) + } + } +"#; + +const PROMOTED_CALLABLE_VARIABLE_ARG: &str = r#" + namespace Test { + struct Pair { Fst: Int, Snd: Int } + + function Foo(p : Pair) : Int { + p.Fst + p.Snd + } + + @EntryPoint() + function Main() : Int { + let pair = new Pair { Fst = 1, Snd = 2 }; + Foo(pair) + } + } +"#; + +const NESTED_TUPLE_LITERAL_INSIDE_IF: &str = r#" + namespace Test { + @EntryPoint() + function Main() : ((Int, Int), (Int, Int)) { + if true { + ((1, 2), (3, 4)) + } else { + ((5, 6), (7, 8)) + } + } + } +"#; + +const 
SIMULATABLE_INTRINSIC_BODY: &str = r#" + namespace Test { + @SimulatableIntrinsic() + operation MyMeasurement(q : Qubit) : Result { + let r = M(q); + r + } + + @EntryPoint() + operation Main() : Result { + use q = Qubit(); + MyMeasurement(q) + } + } +"#; + +#[test] +fn post_all_field_path_on_tuple_passes() { + let (store, pkg_id) = compile_and_run_pipeline_to(STRUCT_FIELD_ACCESS, PipelineStage::Full); + check(&store, pkg_id, InvariantLevel::PostAll); +} + +#[test] +fn post_sroa_tuple_local_pattern_passes() { + let (store, pkg_id) = compile_and_run_pipeline_to(STRUCT_FIELD_ACCESS, PipelineStage::Sroa); + check(&store, pkg_id, InvariantLevel::PostSroa); +} + +#[test] +#[should_panic(expected = "Tuple pattern/type invariant violation")] +fn post_sroa_catches_tuple_local_pattern_arity_mismatch() { + let (mut store, pkg_id) = compile_and_run_pipeline_to(STRUCT_FIELD_ACCESS, PipelineStage::Sroa); + inject_local_tuple_pattern_arity_mismatch(&mut store, pkg_id); + check(&store, pkg_id, InvariantLevel::PostSroa); +} + +#[test] +fn post_arg_promote_tuple_input_pattern_passes() { + let (store, pkg_id) = + compile_and_run_pipeline_to(PROMOTED_CALLABLE_INPUT, PipelineStage::ArgPromote); + check(&store, pkg_id, InvariantLevel::PostArgPromote); +} + +#[test] +#[should_panic(expected = "Tuple pattern/type invariant violation")] +fn post_arg_promote_catches_callable_input_pattern_arity_mismatch() { + let (mut store, pkg_id) = + compile_and_run_pipeline_to(PROMOTED_CALLABLE_INPUT, PipelineStage::ArgPromote); + inject_callable_input_tuple_pattern_arity_mismatch(&mut store, pkg_id, "Foo"); + check(&store, pkg_id, InvariantLevel::PostArgPromote); +} + +#[test] +#[should_panic(expected = "LocalVarId consistency")] +fn post_mono_catches_stale_local_in_simulatable_intrinsic_body() { + let (mut store, pkg_id) = + compile_and_run_pipeline_to(SIMULATABLE_INTRINSIC_BODY, PipelineStage::Mono); + inject_stale_local_var_in_callable( + &mut store, + pkg_id, + "MyMeasurement", + 
LocalVarId::from(9999u32), + ); + check(&store, pkg_id, InvariantLevel::PostMono); +} + +#[test] +#[should_panic(expected = "contains Ty::Udt after UDT erasure")] +fn post_all_catches_simulatable_intrinsic_body_type_violation() { + let (mut store, pkg_id) = + compile_and_run_pipeline_to(SIMULATABLE_INTRINSIC_BODY, PipelineStage::Full); + inject_udt_expr_type_in_callable(&mut store, pkg_id, "MyMeasurement"); + check(&store, pkg_id, InvariantLevel::PostAll); +} + +#[test] +#[should_panic(expected = "Field::Path on non-tuple")] +fn post_all_field_path_on_non_tuple_panics() { + let (mut store, pkg_id) = compile_and_run_pipeline_to(STRUCT_FIELD_ACCESS, PipelineStage::Full); + inject_non_tuple_field_path_target(&mut store, pkg_id); + check(&store, pkg_id, InvariantLevel::PostAll); +} + +#[test] +#[should_panic(expected = "Field::Path on non-tuple")] +fn post_all_catches_nested_field_path_on_non_tuple_inside_if_branch() { + let (mut store, pkg_id) = + compile_and_run_pipeline_to(STRUCT_FIELD_ACCESS_INSIDE_IF, PipelineStage::Full); + inject_nested_non_tuple_field_path_target(&mut store, pkg_id); + check(&store, pkg_id, InvariantLevel::PostAll); +} + +#[test] +fn post_all_binding_type_consistency_passes() { + let (store, pkg_id) = compile_and_run_pipeline_to(SIMPLE_LOCAL_VAR, PipelineStage::Full); + check(&store, pkg_id, InvariantLevel::PostAll); +} + +#[test] +#[should_panic(expected = "PostReturnUnify invariant violation: local binding")] +fn post_all_binding_type_mismatch_panics() { + let (mut store, pkg_id) = compile_and_run_pipeline_to(SIMPLE_LOCAL_VAR, PipelineStage::Full); + inject_binding_type_mismatch(&mut store, pkg_id); + check(&store, pkg_id, InvariantLevel::PostAll); +} + +#[test] +#[should_panic(expected = "PostArgPromote/PostAll call invariant violation")] +fn post_all_catches_call_argument_shape_mismatch() { + let (mut store, pkg_id) = + compile_and_run_pipeline_to(PROMOTED_CALLABLE_VARIABLE_ARG, PipelineStage::Full); + 
inject_call_argument_shape_mismatch(&mut store, pkg_id, "Main"); + check(&store, pkg_id, InvariantLevel::PostAll); +} + +#[test] +#[should_panic(expected = "Tuple arity mismatch")] +fn post_defunc_catches_tuple_arity_mismatch() { + let source = r#" + namespace Test { + @EntryPoint() + function Main() : (Int, Int, Int) { + (1, 2, 3) + } + } + "#; + let (mut store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Defunc); + inject_tuple_arity_mismatch(&mut store, pkg_id); + check(&store, pkg_id, InvariantLevel::PostDefunc); +} + +#[test] +#[should_panic(expected = "Non-Unit block-tail invariant violation")] +fn post_defunc_catches_non_unit_block_tail_violation() { + let (mut store, pkg_id) = compile_and_run_pipeline_to(SIMPLE_LOCAL_VAR, PipelineStage::Defunc); + convert_last_body_expr_to_semi(&mut store, pkg_id); + check(&store, pkg_id, InvariantLevel::PostDefunc); +} + +#[test] +#[should_panic(expected = "PostTupleCompLower invariant violation")] +fn post_tuple_comp_lower_catches_nested_tuple_eq_inside_if_branch() { + let (mut store, pkg_id) = compile_and_run_pipeline_to( + NESTED_TUPLE_LITERAL_INSIDE_IF, + PipelineStage::TupleCompLower, + ); + inject_nested_tuple_eq_in_if_branch(&mut store, pkg_id); + check(&store, pkg_id, InvariantLevel::PostTupleCompLower); +} + +/// Injects a non-existent `StmtId` into the first callable body block's +/// statement list, triggering the ID reference check. +fn inject_dangling_stmt_id(store: &mut PackageStore, pkg_id: qsc_fir::fir::PackageId) { + let pkg = store.get_mut(pkg_id); + for item in pkg.items.values() { + if let ItemKind::Callable(decl) = &item.kind + && let CallableImpl::Spec(spec_impl) = &decl.implementation + { + let block = pkg.blocks.get_mut(spec_impl.body.block).expect("block"); + // Use a StmtId far beyond any that could exist. 
+ block.stmts.push(qsc_fir::fir::StmtId::from(99999u32)); + return; + } + } + panic!("no callable with body block found to mutate"); +} + +#[test] +#[should_panic(expected = "references nonexistent Stmt")] +fn invariant_catches_dangling_stmt_id_in_block() { + let (mut store, pkg_id) = compile_and_run_pipeline_to(SIMPLE_LOCAL_VAR, PipelineStage::Full); + inject_dangling_stmt_id(&mut store, pkg_id); + check(&store, pkg_id, InvariantLevel::PostAll); +} diff --git a/source/compiler/qsc_fir_transforms/src/item_dce.rs b/source/compiler/qsc_fir_transforms/src/item_dce.rs new file mode 100644 index 0000000000..0ddab4d032 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/item_dce.rs @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Item-level dead code elimination. +//! +//! After monomorphization and defunctionalization, many items become +//! unreachable: original generic callables replaced by monomorphized copies, +//! closure items fully specialized, etc. These unreachable items remain in +//! [`Package::items`](qsc_fir::fir::Package). This pass removes them. +//! +//! # Separation from `gc_unreachable` +//! +//! [`gc_unreachable`](crate::gc_unreachable) operates on arena nodes (blocks, +//! stmts, exprs, pats) within a single package. Item-level reachability is +//! cross-package (library items may be referenced from user code), so it +//! requires a [`PackageStore`](qsc_fir::fir::PackageStore) for the +//! reachability walk. This is why item DCE is a separate pass. +//! +//! # `StmtKind::Item` edge case +//! +//! `StmtKind::Item(local_item_id)` stmts declare items inside blocks. If item +//! DCE removes an item but its declaring `StmtKind::Item` stmt is still in a +//! reachable block, `invariants::check_id_references` will panic. The +//! pipeline mitigates this by re-running `gc_unreachable` after item DCE when +//! any items are removed — this tombstones the arena nodes (blocks, stmts, +//! 
exprs, pats) that belonged to the deleted items' bodies. The +//! `StmtKind::Item` stmts themselves survive as dangling references, which is +//! safe because `check_id_references` explicitly allows them post-DCE and +//! `exec_graph_rebuild` ignores `StmtKind::Item` stmts. +//! +//! # Transformation shape +//! +//! **Before:** `Package::items` contains unreachable callable items (original +//! generics replaced by monomorphized copies, fully-specialized closure items) +//! and dead type items left after UDT erasure. +//! +//! **After:** Unreachable items are removed from `Package::items`. If any +//! items were removed, `gc_unreachable` re-runs to tombstone their arena +//! nodes. + +#[cfg(test)] +mod tests; + +use qsc_fir::fir::{ItemKind, LocalItemId, Package, PackageId, Res, StoreItemId}; +use rustc_hash::FxHashSet; + +/// Eliminates unreachable items from the package's item map. +/// +/// The `reachable` set should be the output of +/// [`collect_reachable_from_entry`](crate::reachability::collect_reachable_from_entry). +/// Only items local to this package are considered; cross-package items in the +/// reachable set are ignored. +/// +/// Type items are unconditionally removed (dead after `udt_erase`). Namespace +/// and export items are structural and always preserved. +/// +/// Export targets that resolve to local callables are marked reachable so the +/// preserved exports cannot point at removed items. +/// +/// Returns the number of items removed. +#[allow(clippy::implicit_hasher)] +pub fn eliminate_dead_items( + package_id: PackageId, + package: &mut Package, + reachable: &FxHashSet, +) -> usize { + let mut local_reachable: FxHashSet = reachable + .iter() + .filter(|id| id.package == package_id) + .map(|id| id.item) + .collect(); + + // Mark export targets that resolve to local callables as reachable so + // the preserved exports don't point at removed items. Cross-package + // export targets and unresolved (Res::Err) exports are ignored. 
+ for item in package.items.values() { + if let ItemKind::Export(_name, Res::Item(item_id)) = &item.kind + && item_id.package == package_id + { + local_reachable.insert(item_id.item); + } + } + + let mut removed = 0; + package.items.retain(|id, item| { + let keep = match &item.kind { + // Callable items: keep only if reachable from entry or an export target. + ItemKind::Callable(_) => local_reachable.contains(&id), + // Type items: unconditionally dead after `udt_erase`. + ItemKind::Ty(..) => false, + // Namespace and export items: structural, always preserved. + ItemKind::Namespace(..) | ItemKind::Export(..) => true, + }; + if !keep { + removed += 1; + } + keep + }); + removed +} diff --git a/source/compiler/qsc_fir_transforms/src/item_dce/tests.rs b/source/compiler/qsc_fir_transforms/src/item_dce/tests.rs new file mode 100644 index 0000000000..9e1355f643 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/item_dce/tests.rs @@ -0,0 +1,712 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use crate::PipelineStage; +use crate::test_utils::{compile_and_run_pipeline_to, compile_to_fir}; +use indoc::indoc; +use qsc_data_structures::span::Span; +use qsc_fir::assigner::Assigner; +use qsc_fir::fir::{Ident, Item, ItemId, ItemKind, LocalVarId, PackageLookup, Res, Visibility}; +use std::rc::Rc; + +/// Counts total items in the user package. +fn item_count(package: &qsc_fir::fir::Package) -> usize { + package.items.iter().count() +} + +/// Counts callable items in the user package. 
fn callable_count(package: &qsc_fir::fir::Package) -> usize {
    package
        .items
        .iter()
        .filter(|(_, item)| matches!(item.kind, ItemKind::Callable(_)))
        .count()
}

/// Looks up the local item id of the callable with the given name, panicking
/// when no such callable exists.
fn callable_id_by_name(package: &qsc_fir::fir::Package, name: &str) -> qsc_fir::fir::LocalItemId {
    package
        .items
        .iter()
        .find_map(|(item_id, item)| match &item.kind {
            ItemKind::Callable(decl) if decl.name.name.as_ref() == name => Some(item_id),
            _ => None,
        })
        .unwrap_or_else(|| panic!("callable {name} should exist"))
}

/// Builds a synthetic public `Export` item named `ExportedHelper` that points
/// at `target_id` in `package_id`.
fn make_export_item(
    export_id: qsc_fir::fir::LocalItemId,
    package_id: qsc_fir::fir::PackageId,
    target_id: qsc_fir::fir::LocalItemId,
) -> Item {
    Item {
        id: export_id,
        span: Span::default(),
        parent: None,
        doc: Rc::from(""),
        attrs: vec![],
        visibility: Visibility::Public,
        kind: ItemKind::Export(
            Ident {
                id: LocalVarId::default(),
                span: Span::default(),
                name: Rc::from("ExportedHelper"),
            },
            Res::Item(ItemId {
                package: package_id,
                item: target_id,
            }),
        ),
    }
}

#[test]
fn dce_removes_unreachable_generic_after_monomorphize() {
    // After monomorphization, the original generic callable is unreachable
    // because it has been replaced by monomorphized copies.
    let source = indoc! {"
        namespace Test {
            function Id<'T>(x : 'T) : 'T { x }
            @EntryPoint()
            function Main() : Int { Id(42) }
        }
    "};
    let (store_before, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Gc);
    let items_before = item_count(store_before.get(pkg_id));

    let (store_after, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ItemDce);
    let items_after = item_count(store_after.get(pkg_id));

    assert!(
        items_after < items_before,
        "item DCE should remove unreachable items: before={items_before}, after={items_after}"
    );
}

#[test]
fn dce_preserves_all_reachable_items() {
    // A minimal program where every callable item is reachable.
    let source = indoc! {"
        namespace Test {
            @EntryPoint()
            function Main() : Int { 42 }
        }
    "};
    let (store_before, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Gc);
    let callable_count_before = callable_count(store_before.get(pkg_id));

    let (store_after, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ItemDce);
    let callable_count_after = callable_count(store_after.get(pkg_id));

    assert_eq!(
        callable_count_before, callable_count_after,
        "all callables reachable — nothing should be removed"
    );
}

#[test]
fn dce_with_closure_passes_invariants() {
    // Closures produce StmtKind::Item declarations in outer blocks.
    // After defunc these become specialized items; the original closure item
    // may become unreachable. ItemDce + cascading GC should keep invariants
    // clean.
    let source = indoc! {"
        namespace Test {
            function Apply(f : Int -> Int, x : Int) : Int { f(x) }
            @EntryPoint()
            function Main() : Int { Apply(x -> x + 1, 5) }
        }
    "};
    // Running through Full exercises ItemDce + cascading GC + invariants.
    let (_store, _pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Full);
    // If we reach here, post-DCE invariants (including check_id_references)
    // passed after cascading GC cleaned up any orphaned StmtKind::Item stmts.
}

#[test]
fn dce_on_entry_less_package_is_noop() {
    // Library packages have no entry expression. The pipeline guards against
    // calling collect_reachable_from_entry (which panics) on entry-less
    // packages. Verify the guard works by running the full pipeline — core
    // and std are entry-less, and they must survive untouched.
    let source = indoc! {"
        namespace Test {
            @EntryPoint()
            function Main() : Unit {}
        }
    "};
    let (store, _pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Full);
    // The core package has no entry expression and should still have items.
    let core_id = qsc_fir::fir::PackageId::CORE;
    assert!(
        store.get(core_id).entry.is_none(),
        "core package should have no entry expression"
    );
    assert!(
        item_count(store.get(core_id)) > 0,
        "core package items should be untouched by item DCE"
    );
}

#[test]
fn dce_removes_generic_after_pipeline() {
    // Non-trivial program exercising multiple transform passes.
    // After ItemDce, unreachable original generic callables should be removed.
    let source = indoc! {"
        namespace Test {
            function Id<'T>(x : 'T) : 'T { x }
            operation ApplyOp(q : Qubit, op : Qubit => Unit) : Unit { op(q); }
            @EntryPoint()
            operation Main() : Unit {
                let x = Id(42);
                use q = Qubit();
                ApplyOp(q, H);
                if M(q) == One {
                    X(q);
                }
                Reset(q);
            }
        }
    "};
    let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ItemDce);
    // Verify the original generic Id callable was removed — the monomorphized
    // copy Id should remain.
    let package = store.get(pkg_id);
    let remaining_names: Vec<_> = package
        .items
        .iter()
        .filter_map(|(_, item)| match &item.kind {
            ItemKind::Callable(decl) => Some(decl.name.name.to_string()),
            _ => None,
        })
        .collect();
    assert!(
        !remaining_names.iter().any(|n| n == "Id"),
        "generic Id should be removed; remaining: {remaining_names:?}"
    );
    assert!(
        remaining_names.iter().any(|n| n.starts_with("Id<")),
        "monomorphized Id should survive; remaining: {remaining_names:?}"
    );
}

#[test]
fn dce_benchmark_generic_multiple_instantiations() {
    let source = indoc! {"
        namespace Test {
            function Id<'T>(x : 'T) : 'T { x }
            function Wrap<'T>(x : 'T) : 'T { Id(x) }
            @EntryPoint()
            function Main() : Int { Wrap(42) + Wrap(0) }
        }
    "};
    let (store_before, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Gc);
    let items_before = item_count(store_before.get(pkg_id));

    let (store_after, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ItemDce);
    let items_after = item_count(store_after.get(pkg_id));

    assert!(
        items_after < items_before,
        "DCE should reduce items: before={items_before}, after={items_after}"
    );
    let callables_after = callable_count(store_after.get(pkg_id));
    assert!(
        callables_after > 0,
        "monomorphized callables should survive: {callables_after}"
    );
}

#[test]
fn dce_benchmark_type_declarations_removed() {
    let source = indoc! {"
        namespace Test {
            newtype Pair = (First : Int, Second : Int);
            @EntryPoint()
            function Main() : Int {
                let p = Pair(1, 2);
                p::First + p::Second
            }
        }
    "};
    let (store_before, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Gc);
    let items_before = item_count(store_before.get(pkg_id));

    let (store_after, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ItemDce);
    let items_after = item_count(store_after.get(pkg_id));

    assert!(
        items_after < items_before,
        "DCE should remove type items: before={items_before}, after={items_after}"
    );
}

#[test]
fn dce_benchmark_closure_and_generic() {
    let source = indoc! {"
        namespace Test {
            function Apply<'T>(f : 'T -> 'T, x : 'T) : 'T { f(x) }
            @EntryPoint()
            function Main() : Int { Apply(x -> x + 1, 5) }
        }
    "};
    let (store_before, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Gc);
    let items_before = item_count(store_before.get(pkg_id));

    let (store_after, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ItemDce);
    let items_after = item_count(store_after.get(pkg_id));

    assert!(
        items_after < items_before,
        "DCE should reduce items with closures+generics: before={items_before}, after={items_after}"
    );
    let callables_after = callable_count(store_after.get(pkg_id));
    assert!(
        callables_after > 0,
        "specialized callables should survive: {callables_after}"
    );
}

#[test]
fn dce_preserves_namespace_items() {
    let source = indoc! {"
        namespace Test {
            @EntryPoint()
            function Main() : Int { 42 }
        }
    "};
    let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ItemDce);
    let package = store.get(pkg_id);
    let has_namespace = package
        .items
        .iter()
        .any(|(_, item)| matches!(item.kind, ItemKind::Namespace(..)));
    assert!(has_namespace, "namespace items must survive DCE");
}

#[test]
fn dce_preserves_export_targets() {
    let source = indoc!
{" + namespace Test { + function Helper() : Int { 42 } + function Dead() : Int { 0 } + @EntryPoint() + function Main() : Int { 1 } + } + "}; + let (mut store, pkg_id) = crate::test_utils::compile_to_fir(source); + let helper_id = callable_id_by_name(store.get(pkg_id), "Helper"); + let dead_id = callable_id_by_name(store.get(pkg_id), "Dead"); + let mut assigner = Assigner::from_package(store.get(pkg_id)); + let export_id = assigner.next_item(); + + store + .get_mut(pkg_id) + .items + .insert(export_id, make_export_item(export_id, pkg_id, helper_id)); + + let reachable = crate::reachability::collect_reachable_from_entry(&store, pkg_id); + assert!( + !reachable.contains(&qsc_fir::fir::StoreItemId { + package: pkg_id, + item: helper_id, + }), + "Helper should be unreachable except through the export" + ); + + crate::item_dce::eliminate_dead_items(pkg_id, store.get_mut(pkg_id), &reachable); + let package = store.get(pkg_id); + + assert!( + package.items.contains_key(helper_id), + "export target callable should survive DCE" + ); + assert!( + !package.items.contains_key(dead_id), + "unexported unreachable callable should still be removed" + ); + + let export = package.get_item(export_id); + let ItemKind::Export(_, Res::Item(target)) = &export.kind else { + panic!("export item should survive with an item target") + }; + assert_eq!(target.package, pkg_id); + assert_eq!(target.item, helper_id); + assert!( + package.items.contains_key(target.item), + "export target should not dangle after DCE" + ); +} + +#[test] +fn item_dce_is_idempotent() { + let source = indoc! 
{" + namespace Test { + function Id<'T>(x : 'T) : 'T { x } + @EntryPoint() + function Main() : Int { Id(42) } + } + "}; + let (mut store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ItemDce); + let items_after_first = item_count(store.get(pkg_id)); + + let reachable = crate::reachability::collect_reachable_from_entry(&store, pkg_id); + let removed = crate::item_dce::eliminate_dead_items(pkg_id, store.get_mut(pkg_id), &reachable); + assert_eq!(removed, 0, "second item_dce run should remove nothing"); + assert_eq!( + item_count(store.get(pkg_id)), + items_after_first, + "item count should be unchanged after second item_dce run" + ); +} + +/// Tests validating `item_dce`'s fragile contract regarding temporary dangling +/// `StmtKind::Item` references and export retention. +/// +/// # Contract Summary +/// +/// After `item_dce` removes dead items, the declaring `StmtKind::Item` statements +/// may remain in reachable blocks, creating temporary dangling references. This is +/// **intentional and safe** because: +/// +/// - **`check_id_references` explicitly allows dangling `StmtKind::Item` references +/// post-DCE.** See [`crate::invariants::check_id_references`] for details. +/// - **`exec_graph_rebuild` ignores `StmtKind::Item` statements**, so dangling refs +/// never participate in execution-graph construction. +/// - **The pipeline cascades `gc_unreachable` after `item_dce`** to tombstone arena +/// nodes belonging to deleted items. This repairs the dangling references by +/// cleaning up the statements. +/// +/// This is a **staged-invariant design**: `item_dce` operates only at the item +/// (declaration) level; node-level (block/stmt/expr arena) cleanup is deferred to +/// the downstream garbage-collection pass. Export targets that resolve to local +/// callables are marked reachable by `item_dce` to prevent dangling exports, while +/// unresolved exports are unconditionally preserved. 
+mod item_dce_contracts {
+    use super::*;
+
+    fn dangling_item_refs(package: &qsc_fir::fir::Package) -> Vec<qsc_fir::fir::LocalItemId> {
+        let mut refs = Vec::new();
+        for stmt in package.stmts.values() {
+            if let qsc_fir::fir::StmtKind::Item(item_id) = &stmt.kind
+                && package.items.get(*item_id).is_none()
+            {
+                refs.push(*item_id);
+            }
+        }
+        refs.sort();
+        refs
+    }
+
+    fn insert_item_stmt_in_main(
+        store: &mut qsc_fir::fir::PackageStore,
+        pkg_id: qsc_fir::fir::PackageId,
+        assigner: &mut Assigner,
+        item_id: qsc_fir::fir::LocalItemId,
+    ) {
+        let stmt_id = assigner.next_stmt();
+        let package = store.get_mut(pkg_id);
+        package.stmts.insert(
+            stmt_id,
+            qsc_fir::fir::Stmt {
+                id: stmt_id,
+                span: Span::default(),
+                kind: qsc_fir::fir::StmtKind::Item(item_id),
+                exec_graph_range: crate::EMPTY_EXEC_RANGE,
+            },
+        );
+
+        let main_id = callable_id_by_name(package, "Main");
+        let main_item = package.get_item(main_id);
+        let ItemKind::Callable(main_decl) = &main_item.kind else {
+            panic!("Main should be callable");
+        };
+        let qsc_fir::fir::CallableImpl::Spec(spec) = &main_decl.implementation else {
+            panic!("Main should have a body spec");
+        };
+        let main_block = spec.body.block;
+        package
+            .blocks
+            .get_mut(main_block)
+            .expect("Main body block should exist")
+            .stmts
+            .insert(0, stmt_id);
+    }
+
+    /// Validates that `item_dce` removes dead callables while preserving the
+    /// pipeline's ability to handle temporary dangling `StmtKind::Item` references.
+    ///
+    /// # Contract Being Tested
+    ///
+    /// - Dead callables are removed from `Package::items`.
+    /// - A dead callable declared via `StmtKind::Item` in a reachable block
+    ///   becomes a dangling reference temporarily.
+    /// - The dangling reference is safe: `check_id_references` post-DCE allows it, and
+    ///   `exec_graph_rebuild` ignores `StmtKind::Item` statements.
+    /// - The pipeline repairs it by cascading `gc_unreachable` after `item_dce`.
+    #[test]
+    fn test_temporary_dangling_refs_allowed() {
+        let source = indoc! 
{" + namespace Test { + function Dead() : Int { 0 } + @EntryPoint() + function Main() : Int { 42 } + } + "}; + + let (mut store, pkg_id) = compile_to_fir(source); + let mut assigner = Assigner::from_package(store.get(pkg_id)); + crate::monomorphize::monomorphize(&mut store, pkg_id, &mut assigner); + let dead_id = callable_id_by_name(store.get(pkg_id), "Dead"); + insert_item_stmt_in_main(&mut store, pkg_id, &mut assigner, dead_id); + assert!( + dangling_item_refs(store.get(pkg_id)).is_empty(), + "pre-DCE package should not yet contain dangling item refs" + ); + + // Directly invoke item_dce without cascading gc_unreachable. + let reachable = crate::reachability::collect_reachable_from_entry(&store, pkg_id); + let removed = + crate::item_dce::eliminate_dead_items(pkg_id, store.get_mut(pkg_id), &reachable); + + // Verify the dead item was removed. + assert!( + removed > 0, + "dead callable should have been removed by item_dce" + ); + + assert!( + !dangling_item_refs(store.get(pkg_id)).is_empty(), + "direct item_dce should leave a temporary dangling StmtKind::Item ref" + ); + + // Verify that reachable items (Main) still exist. + let package = store.get(pkg_id); + let has_main = package.items.iter().any(|(_, item)| { + matches!(&item.kind, ItemKind::Callable(decl) if decl.name.name.as_ref() == "Main") + }); + assert!( + has_main, + "reachable callable 'Main' should survive item_dce" + ); + + crate::invariants::check(&store, pkg_id, crate::invariants::InvariantLevel::PostGc); + } + + /// Validates that `item_dce` preserves exports and marks their resolution targets as + /// reachable, preventing dangling export targets. + /// + /// # Contract Being Tested + /// + /// - Export items (structural) are always preserved. + /// - Export targets that resolve to local callables are marked reachable so the + /// preserved export cannot point at a removed item. + /// - Unresolved export targets (`Res::Err`) are tolerated and do not cause removal + /// of the export itself. 
+ #[test] + fn test_export_retention_with_unresolved_targets() { + let source = indoc! {" + namespace Test { + function Helper() : Int { 42 } + @EntryPoint() + function Main() : Int { 1 } + } + "}; + + // Compile to FIR and monomorphize. + let (mut store, pkg_id) = compile_to_fir(source); + let mut assigner = Assigner::from_package(store.get(pkg_id)); + crate::monomorphize::monomorphize(&mut store, pkg_id, &mut assigner); + + // Manually create an export with an unresolved target to validate the contract. + let export_id = assigner.next_item(); + store.get_mut(pkg_id).items.insert( + export_id, + Item { + id: export_id, + span: Span::default(), + parent: None, + doc: Rc::from(""), + attrs: vec![], + visibility: Visibility::Public, + kind: ItemKind::Export( + Ident { + id: LocalVarId::default(), + span: Span::default(), + name: Rc::from("UnresolvedExport"), + }, + Res::Err, // Unresolved target + ), + }, + ); + + let items_before = item_count(store.get(pkg_id)); + + // Run item_dce. + let reachable = crate::reachability::collect_reachable_from_entry(&store, pkg_id); + crate::item_dce::eliminate_dead_items(pkg_id, store.get_mut(pkg_id), &reachable); + + let package = store.get(pkg_id); + + // Contract validation 1: export items are always preserved. + assert!( + package.items.contains_key(export_id), + "export with unresolved target must be retained" + ); + + // Contract validation 2: export structure is unchanged. + let ItemKind::Export(export_name, export_res) = &package.get_item(export_id).kind else { + panic!("export_id should still be an export item"); + }; + assert_eq!( + export_name.name.as_ref(), + "UnresolvedExport", + "export name should be preserved" + ); + assert!( + matches!(export_res, Res::Err), + "unresolved target should remain unresolved after item_dce" + ); + + // Verify that DCE still removes truly dead items (any garbage not exported or reachable). + // The items_before count includes the unresolved export, Main, and possibly others. 
+ // We just verify the export survived; DCE logic is tested elsewhere. + assert!( + item_count(store.get(pkg_id)) <= items_before, + "item count should not increase after item_dce" + ); + } + + #[test] + fn dce_surviving_stmtitem_refs_are_valid() { + // Regression test: Verify StmtKind::Item refs point to valid items after DCE. + // + // Invariant: After item DCE, all surviving StmtKind::Item references within + // reachable callable bodies must reference items that still exist in the package. + // No dangling references should remain. + let source = indoc! {" + namespace Test { + operation Dead() : Unit { } + operation Alive() : Unit { + Dead(); + } + @EntryPoint() + operation Main() : Unit { + Alive(); + } + } + "}; + + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ItemDce); + let package = store.get(pkg_id); + + // Collect all reachable items + let reachable = crate::reachability::collect_reachable_from_entry(&store, pkg_id); + let reachable_local: Vec<_> = reachable + .iter() + .filter_map(|id| { + if id.package == pkg_id { + Some(id.item) + } else { + None + } + }) + .collect(); + + // Verify: For each reachable callable, all StmtKind::Item refs point to valid items + for local_item_id in reachable_local { + if let ItemKind::Callable(callable) = &package.get_item(local_item_id).kind { + let spec = match &callable.implementation { + qsc_fir::fir::CallableImpl::Spec(spec_impl) => &spec_impl.body, + qsc_fir::fir::CallableImpl::SimulatableIntrinsic(spec) => spec, + qsc_fir::fir::CallableImpl::Intrinsic => continue, + }; + + // Collect all statements in the callable body block + let block = package.get_block(spec.block); + for stmt_id in &block.stmts { + let stmt = package.get_stmt(*stmt_id); + if let qsc_fir::fir::StmtKind::Item(item_ref) = &stmt.kind { + assert!( + package.items.contains_key(*item_ref), + "StmtKind::Item reference {item_ref:?} points to non-existent item after DCE" + ); + } + } + } + } + } +} + +#[test] +fn 
pinned_item_survives_item_dce() { + let (mut store, pkg_id) = compile_to_fir(indoc! {" + namespace Test { + @EntryPoint() + operation Main() : Int { 42 } + // Unreachable from entry but will be pinned + operation Pinned() : Int { 99 } + } + "}); + let package = store.get(pkg_id); + let pinned_local = callable_id_by_name(package, "Pinned"); + let pinned_store_id = qsc_fir::fir::StoreItemId { + package: pkg_id, + item: pinned_local, + }; + + let errors = crate::run_pipeline_to( + &mut store, + pkg_id, + PipelineStage::ItemDce, + &[pinned_store_id], + ); + assert!(errors.is_empty()); + + // Pinned item should survive DCE. + let package = store.get(pkg_id); + assert!( + package.items.get(pinned_local).is_some(), + "pinned item should survive DCE" + ); +} + +#[test] +fn pinned_item_transitive_deps_survive_item_dce() { + let (mut store, pkg_id) = compile_to_fir(indoc! {" + namespace Test { + @EntryPoint() + operation Main() : Int { 42 } + // Unreachable from entry but will be pinned + operation Pinned() : Int { Helper() } + // Transitive dep of Pinned, also unreachable from entry + operation Helper() : Int { 77 } + } + "}); + let package = store.get(pkg_id); + let pinned_local = callable_id_by_name(package, "Pinned"); + let helper_local = callable_id_by_name(package, "Helper"); + let pinned_store_id = qsc_fir::fir::StoreItemId { + package: pkg_id, + item: pinned_local, + }; + + let errors = crate::run_pipeline_to( + &mut store, + pkg_id, + PipelineStage::ItemDce, + &[pinned_store_id], + ); + assert!(errors.is_empty()); + + // Both pinned item and its transitive dep should survive DCE. 
+ let package = store.get(pkg_id); + assert!( + package.items.get(pinned_local).is_some(), + "pinned item should survive DCE" + ); + assert!( + package.items.get(helper_local).is_some(), + "transitive dependency of pinned item should survive DCE" + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/lib.rs b/source/compiler/qsc_fir_transforms/src/lib.rs new file mode 100644 index 0000000000..76031c1817 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/lib.rs @@ -0,0 +1,312 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! FIR-to-FIR transformation passes for the Q# compiler. +//! +//! The FIR transform pipeline should run after FIR lowering and before +//! partial evaluation and codegen. It is responsible for monomorphizing +//! generics, eliminating callable-valued expressions, erasing UDTs, and +//! performing various structural rewrites that simplify later stages. +//! The transformations in this crate are not intended to be used as +//! independent passes. Instead, they are ordered and orchestrated by the +//! `run_pipeline` function, which applies the full sequence of +//! transformations in one shot. This is because the passes are not designed +//! to be individually sound or to preserve FIR invariants on their own. +//! For example, defunctionalization produces FIR that violates invariants +//! expected by later passes, but the subsequent UDT erasure and tuple +//! comparison lowering restore those invariants before the next major +//! stage (SROA). +//! +//! At the end of the pipeline, the FIR should be in a form that is +//! semantically equivalent to the input but more amenable to partial +//! evaluation and codegen. +//! +//! This crate defines the production FIR rewrite schedule that runs after FIR +//! lowering. The pipeline monomorphizes reachable callables, rewrites returns +//! to a single-exit form, defunctionalizes callable values, erases UDTs, +//! lowers non-empty tuple +//! 
equality and inequality, scalarizes tuple locals and parameters, and then
+//! rebuilds execution-graph metadata.
+//!
+//! Several passes reuse [`cloner::FirCloner`] for deep-cloning FIR subtrees,
+//! while others rewrite nodes in place or rebuild derived structures from
+//! scratch.
+//!
+//! # Cross-pass contracts
+//!
+//! - **Single [`Assigner`] continuity.** The pipeline constructs one
+//!   [`Assigner`] from the input package and threads it through every pass
+//!   (`monomorphize`, `return_unify`, `defunctionalize`, `udt_erase`,
+//!   `tuple_compare_lower`, `sroa`, `arg_promote`). Each pass allocates fresh
+//!   IDs against that shared counter so synthesized nodes from earlier stages
+//!   stay disjoint from IDs allocated later. Passes must not construct a
+//!   fresh [`Assigner`] mid-pipeline.
+//! - **`EMPTY_EXEC_RANGE` sentinel.** Passes that synthesize new
+//!   [`fir::Expr`](qsc_fir::fir::Expr) or [`fir::Stmt`](qsc_fir::fir::Stmt)
+//!   nodes attach `EMPTY_EXEC_RANGE` as their `exec_graph_range`. The final
+//!   [`exec_graph_rebuild`] pass consumes that sentinel and repopulates the
+//!   execution graph from the rewritten FIR.

+use miette::Diagnostic;
+use qsc_fir::assigner::Assigner;
+use qsc_fir::fir::{ExecGraphIdx, PackageId, PackageStore, StoreItemId};
+use thiserror::Error;
+
+/// An empty execution graph range for synthesized FIR nodes that do not
+/// participate in the execution graph.
+pub(crate) const EMPTY_EXEC_RANGE: std::ops::Range<ExecGraphIdx> = std::ops::Range {
+    start: ExecGraphIdx::ZERO,
+    end: ExecGraphIdx::ZERO,
+};
+
+/// Errors produced by the FIR transform pipeline.
+///
+/// Wraps pass-specific error types so callers handle a single diagnostic
+/// type from [`run_pipeline`] and [`run_pipeline_to`].
+#[derive(Clone, Debug, Diagnostic, Error)]
+pub enum PipelineError {
+    /// A return-unification error (e.g., unsupported return type inside a loop). 
+ #[error(transparent)] + #[diagnostic(transparent)] + ReturnUnify(#[from] return_unify::Error), + + /// A defunctionalization error (e.g., dynamic callable, convergence failure). + #[error(transparent)] + #[diagnostic(transparent)] + Defunctionalize(#[from] defunctionalize::Error), +} + +/// How far through the FIR transform schedule to run. +/// +/// `Sroa`, `ArgPromote`, and `ExecGraphRebuild` are mainly used by tests and +/// internal validation helpers. +/// Production uses `Full`. +#[doc(hidden)] +#[cfg_attr(not(test), allow(dead_code))] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum PipelineStage { + /// Run through monomorphization. + Mono, + /// Run through return unification. + ReturnUnify, + /// Run through defunctionalization. + Defunc, + /// Run through UDT erasure. + UdtErase, + /// Run through tuple comparison lowering. + TupleCompLower, + /// Run through SROA. + Sroa, + /// Run through argument promotion. + ArgPromote, + /// Run through unreachable-node garbage collection. + Gc, + /// Run through item-level dead code elimination. + ItemDce, + /// Run through exec graph rebuild. + ExecGraphRebuild, + /// Run the full pipeline. + Full, +} + +pub mod cloner; +pub(crate) mod fir_builder; +pub mod invariants; +pub mod pretty; +pub mod reachability; + +pub mod arg_promote; +pub mod defunctionalize; +pub mod exec_graph_rebuild; +pub mod gc_unreachable; +pub mod item_dce; +pub mod monomorphize; +pub mod return_unify; +pub mod sroa; +pub mod tuple_compare_lower; +pub mod udt_erase; + +#[cfg(any(test, feature = "testutil"))] +pub mod test_utils; + +pub(crate) mod walk_utils; + +/// Runs the FIR transform schedule up to `stage`, threading a single +/// [`Assigner`] through every pass. +/// +/// The [`Assigner`] is constructed once from the input package and passed by +/// mutable reference to each pass so ID allocations from earlier stages are +/// observed by later stages. 
Between major stages the function invokes
+/// [`invariants::check`] with the corresponding [`invariants::InvariantLevel`].
+///
+/// If [`return_unify::unify_returns`] or
+/// [`defunctionalize::defunctionalize`] reports any diagnostics the function
+/// returns them immediately, skipping subsequent passes and invariant checks.
+/// The intermediate FIR at that point intentionally violates downstream
+/// invariants, so running later passes would produce misleading failures.
+/// Test helpers rely on this early-return to inspect the errors before
+/// later invariant checks or downstream passes fail on the intentionally
+/// invalid intermediate FIR.
+fn run_pipeline_to_impl(
+    store: &mut PackageStore,
+    package_id: PackageId,
+    stage: PipelineStage,
+    pinned_items: &[StoreItemId],
+) -> Vec<PipelineError> {
+    assert!(
+        store.get(package_id).entry.is_some(),
+        "FIR transform pipeline requires a package with an entry expression; \
+         library packages should not be passed to the transform pipeline"
+    );
+    let mut assigner = Assigner::from_package(store.get(package_id));
+
+    monomorphize::monomorphize(store, package_id, &mut assigner);
+    invariants::check(store, package_id, invariants::InvariantLevel::PostMono);
+    if matches!(stage, PipelineStage::Mono) {
+        return Vec::new();
+    }
+
+    let ru_errors = return_unify::unify_returns(store, package_id, &mut assigner);
+    if !ru_errors.is_empty() {
+        return ru_errors.into_iter().map(PipelineError::from).collect();
+    }
+    invariants::check(
+        store,
+        package_id,
+        invariants::InvariantLevel::PostReturnUnify,
+    );
+    if matches!(stage, PipelineStage::ReturnUnify) {
+        return Vec::new();
+    }
+
+    let errors = defunctionalize::defunctionalize(store, package_id, &mut assigner);
+    if !errors.is_empty() {
+        return errors.into_iter().map(PipelineError::from).collect();
+    }
+
+    invariants::check(store, package_id, invariants::InvariantLevel::PostDefunc);
+    if matches!(stage, PipelineStage::Defunc) {
+        return Vec::new();
+    }
+
+    
udt_erase::erase_udts(store, package_id, &mut assigner); + invariants::check(store, package_id, invariants::InvariantLevel::PostUdtErase); + if matches!(stage, PipelineStage::UdtErase) { + return Vec::new(); + } + + tuple_compare_lower::lower_tuple_comparisons(store, package_id, &mut assigner); + invariants::check( + store, + package_id, + invariants::InvariantLevel::PostTupleCompLower, + ); + if matches!(stage, PipelineStage::TupleCompLower) { + return Vec::new(); + } + + sroa::sroa(store, package_id, &mut assigner); + invariants::check(store, package_id, invariants::InvariantLevel::PostSroa); + if matches!(stage, PipelineStage::Sroa) { + return Vec::new(); + } + + arg_promote::arg_promote(store, package_id, &mut assigner); + invariants::check( + store, + package_id, + invariants::InvariantLevel::PostArgPromote, + ); + if matches!(stage, PipelineStage::ArgPromote) { + return Vec::new(); + } + + gc_unreachable::gc_unreachable(store.get_mut(package_id)); + invariants::check(store, package_id, invariants::InvariantLevel::PostGc); + if matches!(stage, PipelineStage::Gc) { + return Vec::new(); + } + + // Item DCE: remove unreachable callable items and dead type items. + // Callers may pin items via `pinned_items` to keep them (and their + // transitive dependencies) alive through DCE and exec-graph-rebuild. + run_item_dce_and_gc(store, package_id, pinned_items); + if matches!(stage, PipelineStage::ItemDce) { + return Vec::new(); + } + + exec_graph_rebuild::rebuild_exec_graphs(store, package_id, pinned_items); + if matches!(stage, PipelineStage::ExecGraphRebuild) { + return Vec::new(); + } + + // PostAll uses entry-only reachability. Pinned items (original target kept + // for fir_to_qir_from_callable) retain pre-transform types and are not checked. + invariants::check(store, package_id, invariants::InvariantLevel::PostAll); + Vec::new() +} + +/// Runs item-level DCE with optional pinned-root expansion, followed by +/// conditional GC if any items were removed. 
+///
+/// Pinned items are NOT invariant-checked — `PostAll` uses entry-only
+/// reachability. Pinning is needed when the original target ID is used
+/// by `fir_to_qir_from_callable` after defunc rewrites the entry `Call`
+/// to reference the specialized callable.
+fn run_item_dce_and_gc(
+    store: &mut PackageStore,
+    package_id: PackageId,
+    pinned_items: &[StoreItemId],
+) {
+    let reachable = if pinned_items.is_empty() {
+        reachability::collect_reachable_from_entry(store, package_id)
+    } else {
+        reachability::collect_reachable_with_seeds(store, package_id, pinned_items)
+    };
+    let removed = item_dce::eliminate_dead_items(package_id, store.get_mut(package_id), &reachable);
+    if removed > 0 {
+        gc_unreachable::gc_unreachable(store.get_mut(package_id));
+    }
+}
+
+/// Runs the authoritative FIR optimization schedule up to the requested stage.
+///
+/// Production uses `PipelineStage::Full`. Intermediate cut points exist so
+/// crate tests can reuse the real production ordering without re-implementing
+/// it in helper code.
+#[doc(hidden)]
+pub fn run_pipeline_to(
+    store: &mut PackageStore,
+    package_id: PackageId,
+    stage: PipelineStage,
+    pinned_items: &[StoreItemId],
+) -> Vec<PipelineError> {
+    run_pipeline_to_impl(store, package_id, stage, pinned_items)
+}
+
+/// Runs the full FIR optimization pipeline on the given package. 
+///
+/// The pipeline applies the following passes in order:
+/// - Monomorphization: eliminates generic callables
+/// - Return unification: rewrites callable bodies to a single-exit form
+/// - Defunctionalization: eliminates callable-valued expressions
+/// - UDT erasure: replaces `Ty::Udt` with pure tuple or scalar types
+/// - Tuple comparison lowering: rewrites `BinOp(Eq/Neq)` on non-empty tuple
+///   operands into element-wise scalar comparisons
+/// - SROA (iterative): decomposes tuple-typed locals into scalars
+/// - Argument promotion (iterative): decomposes tuple-typed callable
+///   parameters into scalars
+/// - GC unreachable: tombstones orphaned arena nodes
+/// - Item DCE: removes unreachable items from the item map, then re-runs
+///   GC to tombstone orphaned `StmtKind::Item` stmts
+/// - Exec graph rebuild: recomputes exec graph ranges after synthesized FIR
+///   nodes are introduced
+///
+/// Invariant checks are inserted between the major structural stages and after
+/// the final rebuild to catch structural violations early.
+///
+/// Returns any errors produced by the transform pipeline. An empty vector
+/// indicates success.
+pub fn run_pipeline(store: &mut PackageStore, package_id: PackageId) -> Vec<PipelineError> {
+    run_pipeline_to(store, package_id, PipelineStage::Full, &[])
+}
diff --git a/source/compiler/qsc_fir_transforms/src/monomorphize.rs b/source/compiler/qsc_fir_transforms/src/monomorphize.rs
new file mode 100644
index 0000000000..f5634f6fde
--- /dev/null
+++ b/source/compiler/qsc_fir_transforms/src/monomorphize.rs
@@ -0,0 +1,928 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+//! Monomorphization pass.
+//!
+//! Eliminates all generic callable references in entry-reachable code by
+//! creating concrete specializations for each unique `(callable, generic_args)`
+//! pair and rewriting call sites to use those specializations.
+//!
+//! 
Establishes [`crate::invariants::InvariantLevel::PostMono`]: no `Ty::Param` +//! remains in reachable code and every `ExprKind::Var` node carries an empty +//! generic-argument list. +//! +//! The algorithm operates in three phases. Discovery walks all entry-reachable +//! code collecting every concrete generic reference. Specialization processes +//! these references via a worklist: for each `(callable, args)` pair it clones +//! the callable body, substitutes type parameters with concrete types, and +//! scans the result for transitive generic references that are fed back into +//! the worklist. Rewrite then redirects all call sites to the newly created +//! specialized callables. +//! +//! # Input patterns +//! +//! - `ExprKind::Var(Res::Item(id), [GenericArg::Ty(Int)])` — a generic call +//! site whose arguments are fully concrete. +//! - `CallableDecl` with non-empty `generics` — a generic callable that will +//! be cloned once per distinct concrete instantiation. +//! +//! # Rewrites +//! +//! Given `function Identity<'T>(x : 'T) : 'T { x }` invoked as `Identity(42)`: +//! +//! ```text +//! // Before +//! Call(Var(Identity, [Ty(Int)]), 42) +//! +//! // After +//! Call(Var(Identity, []), 42) +//! ``` +//! +//! A new `Identity` callable is inserted into the target package with +//! all `Ty::Param` nodes substituted for `Int`, and the call site loses its +//! generic-argument list. +//! +//! # Notes +//! +//! - Identity instantiations (`[Param(0), Param(1), ...]`) are skipped; they +//! would produce a duplicate identical to the original generic callable. +//! - Intrinsics whose call sites use concrete generic arguments have their +//! argument lists cleared in place (no new callable is synthesized). +//! - Cross-package references are cloned into the target package so the +//! specialized bodies are self-contained. 
+
+#[cfg(test)]
+mod tests;
+
+#[cfg(all(test, feature = "slow-proptest-tests"))]
+mod semantic_equivalence_tests;
+
+use crate::cloner::FirCloner;
+use crate::fir_builder::{functored_specs, reachable_local_callables};
+use crate::reachability::collect_reachable_from_entry;
+use crate::walk_utils::{
+    collect_expr_ids_in_entry_and_local_callables, extend_expr_ids_in_local_callables,
+};
+use qsc_fir::assigner::Assigner;
+use qsc_fir::fir::{
+    BlockId, CallableDecl, CallableImpl, ExprId, ExprKind, Ident, Item, ItemId, ItemKind,
+    LocalItemId, LocalVarId, Package, PackageId, PackageLookup, PackageStore, PatId, PatKind, Res,
+    StmtId, StmtKind, StoreItemId, Visibility,
+};
+use qsc_fir::ty::{Arrow, FunctorSet, GenericArg, ParamId, Ty};
+use rustc_hash::{FxHashMap, FxHashSet};
+use std::collections::VecDeque;
+use std::rc::Rc;
+
+/// A recorded specialization: the source callable + args, and where it was
+/// placed in the target package.
+struct Specialization {
+    source: StoreItemId,
+    args: Vec<GenericArg>,
+    new_item_id: ItemId,
+}
+
+/// Monomorphizes all generic callable references in the entry-reachable portion
+/// of a package.
+///
+/// After this pass, no `Ty::Param` or `FunctorSet::Param` values remain in
+/// reachable code, and all `ExprKind::Var` nodes have empty generic-argument
+/// lists.
+///
+/// # Panics
+///
+/// Panics if the package has no entry expression.
+pub fn monomorphize(store: &mut PackageStore, package_id: PackageId, assigner: &mut Assigner) {
+    let package = store.get(package_id);
+    assert!(
+        package.entry.is_some(),
+        "monomorphize requires a package entry expression"
+    );
+
+    let instantiations = discover_instantiations(store, package_id);
+    if instantiations.is_empty() {
+        return;
+    }
+
+    // Take ownership of the assigner for the duration of specialization
+    // and restore it afterward with advanced counters.
+    let owned_assigner = std::mem::take(assigner);
+
+    // Create specialized (monomorphized) callables. 
+ let (specializations, returned_assigner) = + create_specializations(store, package_id, instantiations, owned_assigner); + *assigner = returned_assigner; + + let expr_ids = collect_rewrite_scope(store, package_id, &specializations); + + let package = store.get_mut(package_id); + rewrite_call_sites(package, package_id, &specializations, &expr_ids); +} + +/// Collects all expression IDs that may contain generic call sites requiring +/// rewriting: entry-reachable callables, newly created specializations, and +/// any closure items transitively referenced by those specializations. +fn collect_rewrite_scope( + store: &PackageStore, + package_id: PackageId, + specializations: &[Specialization], +) -> Vec { + let reachable = collect_reachable_from_entry(store, package_id); + let package = store.get(package_id); + let local_item_ids: Vec<_> = reachable_local_callables(package, package_id, &reachable) + .map(|(id, _)| id) + .collect(); + let mut expr_ids = collect_expr_ids_in_entry_and_local_callables(package, &local_item_ids); + let new_item_ids: Vec<_> = specializations.iter().map(|s| s.new_item_id.item).collect(); + let mut seen: FxHashSet = expr_ids.iter().copied().collect(); + + // We computed reachability after creating specializations but before + // rewriting call sites, so new specializations aren't reachable from + // entry yet. Those new specializations may reference newly-cloned + // closure items that are also unreachable from entry until call sites + // are redirected. + let mut walked_items: FxHashSet = local_item_ids.iter().copied().collect(); + walked_items.extend(new_item_ids.iter().copied()); + + let mut scan_start = expr_ids.len(); + extend_expr_ids_in_local_callables(package, &new_item_ids, &mut expr_ids, &mut seen); + + // Transitively walk closure items whose bodies may also contain generic + // call sites that need rewriting. + loop { + let mut new_closures = Vec::new(); + for &expr_id in &expr_ids[scan_start..] 
{ + if let ExprKind::Closure(_, local_item_id) = &package.get_expr(expr_id).kind + && walked_items.insert(*local_item_id) + { + new_closures.push(*local_item_id); + } + } + if new_closures.is_empty() { + break; + } + scan_start = expr_ids.len(); + extend_expr_ids_in_local_callables(package, &new_closures, &mut expr_ids, &mut seen); + } + + expr_ids +} + +/// Walks all entry-reachable code and collects every unique +/// `(StoreItemId, Vec)` pair where the generic args are non-empty +/// and fully concrete. +fn discover_instantiations( + store: &PackageStore, + package_id: PackageId, +) -> Vec<(StoreItemId, Vec)> { + let reachable = collect_reachable_from_entry(store, package_id); + let mut found: Vec<(StoreItemId, Vec)> = Vec::new(); + let mut seen_keys: FxHashSet = FxHashSet::default(); + + let package = store.get(package_id); + + // Walk the entry expression. + if let Some(entry_id) = package.entry { + collect_generic_refs_in_expr(package, entry_id, &mut found, &mut seen_keys); + } + + // Walk every reachable callable body. + for item_id in &reachable { + let pkg = store.get(item_id.package); + let Some(item) = pkg.items.get(item_id.item) else { + // Interpreter entry expressions can carry runtime-unbound item references + // after a rejected callable definition. Leave those for later evaluation + // diagnostics instead of panicking during reachability discovery. + continue; + }; + if let ItemKind::Callable(decl) = &item.kind { + collect_generic_refs_in_callable(pkg, decl, &mut found, &mut seen_keys); + } + } + + found.retain(|(_, args)| is_fully_concrete(args)); + + found +} + +/// Deterministic dedup key for a `(StoreItemId, &[GenericArg])` pair. 
+fn mono_key(source: StoreItemId, args: &[GenericArg]) -> String { + use std::fmt::Write; + let mut key = format!("{source}:"); + for (i, arg) in args.iter().enumerate() { + if i > 0 { + key.push(','); + } + write!(key, "{arg}").expect("formatting should not fail"); + } + key +} + +/// Builds a unique mangled name for a monomorphized callable by appending the +/// concrete generic arguments to the base name using `` notation. +/// +/// Functor set arguments use compact identifiers (`Empty`, `Adj`, `Ctl`, +/// `AdjCtl`) instead of the user-facing display forms. The intrinsic callable +/// (`CallableImpl::Intrinsic`) `Length` is exempt because downstream passes +/// match on that name literally. +fn mono_name(decl: &CallableDecl, args: &[GenericArg]) -> Rc { + use std::fmt::Write; + if matches!(decl.implementation, CallableImpl::Intrinsic) && decl.name.name.as_ref() == "Length" + { + return Rc::clone(&decl.name.name); + } + let mut name = decl.name.name.to_string(); + name.push('<'); + for (i, arg) in args.iter().enumerate() { + if i > 0 { + name.push_str(", "); + } + match arg { + GenericArg::Ty(ty) => write!(name, "{ty}").expect("formatting should not fail"), + GenericArg::Functor(FunctorSet::Value(v)) => name.push_str(v.mangle_name()), + GenericArg::Functor(f) => write!(name, "{f}").expect("formatting should not fail"), + } + } + name.push('>'); + Rc::from(name.as_str()) +} + +/// Walks a callable's body collecting every `(StoreItemId, Vec)` +/// pair referenced by `ExprKind::Var(Res::Item(..), args)` with non-empty +/// generic arguments, deduplicated via `mono_key` in `seen`. 
+fn collect_generic_refs_in_callable( + pkg: &Package, + decl: &CallableDecl, + found: &mut Vec<(StoreItemId, Vec)>, + seen: &mut FxHashSet, +) { + crate::walk_utils::for_each_expr_in_callable_impl( + pkg, + &decl.implementation, + &mut |_eid, expr| { + if let ExprKind::Var(Res::Item(item_id), generic_args) = &expr.kind + && !generic_args.is_empty() + { + let store_id = StoreItemId::from((item_id.package, item_id.item)); + let key = mono_key(store_id, generic_args); + if seen.insert(key) { + found.push((store_id, generic_args.clone())); + } + } + }, + ); +} + +/// Walks a single expression subtree collecting `(StoreItemId, Vec)` +/// pairs the same way as [`collect_generic_refs_in_callable`], used for the +/// package entry expression. +fn collect_generic_refs_in_expr( + pkg: &Package, + expr_id: ExprId, + found: &mut Vec<(StoreItemId, Vec)>, + seen: &mut FxHashSet, +) { + crate::walk_utils::for_each_expr(pkg, expr_id, &mut |_eid, expr| { + if let ExprKind::Var(Res::Item(item_id), generic_args) = &expr.kind + && !generic_args.is_empty() + { + let store_id = StoreItemId::from((item_id.package, item_id.item)); + let key = mono_key(store_id, generic_args); + if seen.insert(key) { + found.push((store_id, generic_args.clone())); + } + } + }); +} + +/// Returns `true` when all generic args map to their own parameter position — +/// e.g., `[Param(0), Param(1)]` for a 2-parameter callable. Cloning with such +/// args would produce a useless duplicate identical to the original generic. +fn is_identity_instantiation(args: &[GenericArg]) -> bool { + args.iter().enumerate().all(|(i, arg)| match arg { + GenericArg::Ty(Ty::Param(p)) | GenericArg::Functor(FunctorSet::Param(p)) => { + *p == ParamId::from(i) + } + _ => false, + }) +} + +/// Returns `true` when no `Ty::Param` or `FunctorSet::Param` appears at any +/// depth inside the given generic args. 
+fn is_fully_concrete(args: &[GenericArg]) -> bool { + args.iter().all(|arg| match arg { + GenericArg::Ty(ty) => !ty_contains_param(ty), + GenericArg::Functor(FunctorSet::Param(_)) => false, + GenericArg::Functor(_) => true, + }) +} + +/// Returns `true` when a `Ty` contains a `Ty::Param` or `FunctorSet::Param` +/// anywhere in its structure. +fn ty_contains_param(ty: &Ty) -> bool { + match ty { + Ty::Param(_) => true, + Ty::Array(inner) => ty_contains_param(inner), + Ty::Arrow(arrow) => { + ty_contains_param(&arrow.input) + || ty_contains_param(&arrow.output) + || matches!(arrow.functors, FunctorSet::Param(_)) + } + Ty::Tuple(items) => items.iter().any(ty_contains_param), + _ => false, + } +} + +/// Walks a cloned callable body and collects every +/// `ExprKind::Var(Res::Item(id), args)` where `args` is non-empty and fully +/// concrete (no remaining `Ty::Param` or `FunctorSet::Param`). +fn scan_for_concrete_generic_refs( + pkg: &Package, + decl: &CallableDecl, +) -> Vec<(StoreItemId, Vec)> { + let mut found = Vec::new(); + let mut seen = FxHashSet::default(); + collect_generic_refs_in_callable(pkg, decl, &mut found, &mut seen); + found.retain(|(_, args)| is_fully_concrete(args)); + found +} + +#[allow(clippy::too_many_lines)] +/// Drives the worklist that clones each requested `(callable, args)` pair +/// into the target package, substitutes type parameters, and scans the +/// cloned bodies for additional transitively-referenced generic sites. +/// +/// Returns the inserted specializations plus the assigner so its counter +/// can be threaded back into the pipeline. +fn create_specializations( + store: &mut PackageStore, + target_pkg_id: PackageId, + instantiations: Vec<(StoreItemId, Vec)>, + assigner: Assigner, +) -> (Vec, Assigner) { + let mut specializations = Vec::new(); + + // Pre-populate seen keys from initial discovery. 
+ let mut seen_keys: FxHashSet = instantiations + .iter() + .map(|(source, args)| mono_key(*source, args)) + .collect(); + let mut worklist: VecDeque<(StoreItemId, Vec)> = instantiations.into(); + + // Temporarily take the target package out of the store so we can hold + // `&source_pkg` (for cross-package) and `&mut target_pkg` simultaneously. + let empty_pkg = empty_package(); + let mut target_pkg = std::mem::replace(store.get_mut(target_pkg_id), empty_pkg); + + let mut cloner = FirCloner::from_assigner(assigner); + + while let Some((source_id, args)) = worklist.pop_front() { + // Skip identity instantiations — cloning with these produces a + // useless duplicate identical to the original generic callable. + if is_identity_instantiation(&args) { + continue; + } + + // Extract needed data from the source package (read-only). + let (body_pkg, decl_snapshot) = { + let source_pkg: &Package = if source_id.package == target_pkg_id { + &target_pkg + } else { + store.get(source_id.package) + }; + let source_item = source_pkg.get_item(source_id.item); + let source_decl = match &source_item.kind { + ItemKind::Callable(decl) => decl.as_ref(), + _ => continue, + }; + let body_pkg = extract_callable_body(source_pkg, source_decl); + let decl_snapshot = source_decl.clone(); + (body_pkg, decl_snapshot) + }; // source_pkg borrow released + + // Clone body into target, substitute types, insert (mutate). + let new_local_id = cloner.alloc_item(); + let new_item_id = ItemId { + package: target_pkg_id, + item: new_local_id, + }; + let old_item_id = ItemId { + package: source_id.package, + item: source_id.item, + }; + + // Reserve the item slot so that clone_nested_item (called during + // clone_callable_impl for StmtKind::Item / ExprKind::Closure) does + // not allocate the same LocalItemId for a nested item. 
+ target_pkg.items.insert( + new_local_id, + Item { + id: new_local_id, + span: decl_snapshot.span, + parent: None, + doc: Rc::from(""), + attrs: vec![], + visibility: Visibility::Public, + kind: ItemKind::Namespace( + Ident { + id: LocalVarId::default(), + span: decl_snapshot.name.span, + name: Rc::from(""), + }, + vec![], + ), + }, + ); + + cloner.reset_maps(); + cloner.set_self_item_remap(old_item_id, new_item_id); + + // Clone input BEFORE impl so that `local_map` contains input + // parameter mappings when the callable body is walked. + let new_input = cloner.clone_input_pat(&body_pkg, decl_snapshot.input, &mut target_pkg); + let new_impl = + cloner.clone_callable_impl(&body_pkg, &decl_snapshot.implementation, &mut target_pkg); + let new_node_id = cloner.next_node(); + + // Substitute Ty::Param / FunctorSet::Param in all cloned nodes. + let arg_map = build_arg_map(&args); + substitute_types_in_cloned_nodes(&mut target_pkg, &cloner, &arg_map); + + let output = substitute_ty(&decl_snapshot.output, &arg_map); + + let spec_name = mono_name(&decl_snapshot, &args); + let spec_decl = CallableDecl { + id: new_node_id, + span: decl_snapshot.span, + kind: decl_snapshot.kind, + name: Ident { + id: LocalVarId::default(), + span: decl_snapshot.name.span, + name: spec_name, + }, + generics: vec![], + input: new_input, + output, + functors: decl_snapshot.functors, + implementation: new_impl, + attrs: decl_snapshot.attrs.clone(), + }; + + let new_item = Item { + id: new_local_id, + span: decl_snapshot.span, + parent: None, + doc: Rc::from(""), + attrs: vec![], + visibility: Visibility::Public, + kind: ItemKind::Callable(Box::new(spec_decl)), + }; + target_pkg.items.insert(new_local_id, new_item); + + // Scan the newly created callable for additional concrete + // generic references that need their own specializations. 
Skip + // references to items in the target package that are already + // non-generic (e.g., self-references from recursive callables that + // were remapped by set_self_item_remap). + let created_item = target_pkg.items.get(new_local_id).expect("just inserted"); + if let ItemKind::Callable(created_decl) = &created_item.kind { + let new_refs = scan_for_concrete_generic_refs(&target_pkg, created_decl); + for (ref_id, ref_args) in new_refs { + if ref_id.package == target_pkg_id + && let Some(ref_item) = target_pkg.items.get(ref_id.item) + && let ItemKind::Callable(ref_decl) = &ref_item.kind + && ref_decl.generics.is_empty() + { + continue; + } + let key = mono_key(ref_id, &ref_args); + if seen_keys.insert(key) { + worklist.push_back((ref_id, ref_args)); + } + } + } + + specializations.push(Specialization { + source: source_id, + args, + new_item_id, + }); + } + + // Put the target package back. + *store.get_mut(target_pkg_id) = target_pkg; + + (specializations, cloner.into_assigner()) +} + +/// Constructs an empty `Package` used as a scratch container for body +/// extraction and for temporarily swapping out the target package during +/// specialization. +fn empty_package() -> Package { + Package::default() +} + +/// Builds a standalone `Package` holding all nodes transitively referenced +/// by a callable's body so that [`FirCloner`] can read from it without +/// holding a reference to the original source package. +fn extract_callable_body(source_pkg: &Package, decl: &CallableDecl) -> Package { + let mut body_pkg = empty_package(); + + // Input pattern. 
+ extract_pat(source_pkg, decl.input, &mut body_pkg); + + match &decl.implementation { + CallableImpl::Intrinsic => {} + CallableImpl::Spec(spec_impl) => { + extract_spec_decl_body(source_pkg, &spec_impl.body, &mut body_pkg); + for spec in functored_specs(spec_impl) { + extract_spec_decl_body(source_pkg, spec, &mut body_pkg); + } + } + CallableImpl::SimulatableIntrinsic(spec) => { + extract_spec_decl_body(source_pkg, spec, &mut body_pkg); + } + } + + body_pkg +} + +/// Copies the input pattern and body block of a `SpecDecl` from `source` into +/// `target`. +fn extract_spec_decl_body(source: &Package, spec: &qsc_fir::fir::SpecDecl, target: &mut Package) { + if let Some(pat_id) = spec.input { + extract_pat(source, pat_id, target); + } + extract_block(source, spec.block, target); +} + +/// Recursively copies a block and all statements it references. +fn extract_block(source: &Package, block_id: BlockId, target: &mut Package) { + if target.blocks.contains_key(block_id) { + return; + } + let block = source.get_block(block_id); + target.blocks.insert(block_id, block.clone()); + for &stmt_id in &block.stmts { + extract_stmt(source, stmt_id, target); + } +} + +/// Recursively copies a statement and any patterns, expressions, or items it +/// references. +fn extract_stmt(source: &Package, stmt_id: StmtId, target: &mut Package) { + if target.stmts.contains_key(stmt_id) { + return; + } + let stmt = source.get_stmt(stmt_id); + target.stmts.insert(stmt_id, stmt.clone()); + match &stmt.kind { + StmtKind::Expr(e) | StmtKind::Semi(e) => extract_expr(source, *e, target), + StmtKind::Local(_, pat, expr) => { + extract_pat(source, *pat, target); + extract_expr(source, *expr, target); + } + StmtKind::Item(item_id) => { + extract_item(source, *item_id, target); + } + } +} + +/// Recursively copies an expression and its transitive references. +/// +/// NOTE: This is intentionally a separate implementation from the nearly +/// identical `extract_expr` in `defunctionalize/specialize.rs`. 
The key +/// difference is the `ExprKind::Closure` arm: monomorphize follows the +/// closure's lifted item via [`extract_item`] because type substitution +/// (`Ty::Param` → concrete) must be applied to the lambda body when a +/// generic callable is monomorphized. Without extracting the item, +/// `substitute_types_in_cloned_nodes` would miss it. +fn extract_expr(source: &Package, expr_id: ExprId, target: &mut Package) { + if target.exprs.contains_key(expr_id) { + return; + } + let expr = source.get_expr(expr_id); + target.exprs.insert(expr_id, expr.clone()); + match &expr.kind { + ExprKind::Array(es) | ExprKind::ArrayLit(es) | ExprKind::Tuple(es) => { + for &e in es { + extract_expr(source, e, target); + } + } + ExprKind::ArrayRepeat(a, b) + | ExprKind::Assign(a, b) + | ExprKind::AssignOp(_, a, b) + | ExprKind::BinOp(_, a, b) + | ExprKind::Call(a, b) + | ExprKind::Index(a, b) + | ExprKind::AssignField(a, _, b) + | ExprKind::UpdateField(a, _, b) => { + extract_expr(source, *a, target); + extract_expr(source, *b, target); + } + ExprKind::AssignIndex(a, b, c) | ExprKind::UpdateIndex(a, b, c) => { + extract_expr(source, *a, target); + extract_expr(source, *b, target); + extract_expr(source, *c, target); + } + ExprKind::Block(block_id) => extract_block(source, *block_id, target), + ExprKind::Fail(e) | ExprKind::Field(e, _) | ExprKind::Return(e) | ExprKind::UnOp(_, e) => { + extract_expr(source, *e, target); + } + ExprKind::If(cond, body, otherwise) => { + extract_expr(source, *cond, target); + extract_expr(source, *body, target); + if let Some(e) = otherwise { + extract_expr(source, *e, target); + } + } + ExprKind::Range(s, st, e) => { + for x in [s, st, e].into_iter().flatten() { + extract_expr(source, *x, target); + } + } + ExprKind::Struct(_, copy, fields) => { + if let Some(c) = copy { + extract_expr(source, *c, target); + } + for fa in fields { + extract_expr(source, fa.value, target); + } + } + ExprKind::String(components) => { + for c in components { + if let 
qsc_fir::fir::StringComponent::Expr(e) = c { + extract_expr(source, *e, target); + } + } + } + ExprKind::While(cond, block) => { + extract_expr(source, *cond, target); + extract_block(source, *block, target); + } + ExprKind::Closure(_, local_item_id) => { + extract_item(source, *local_item_id, target); + } + ExprKind::Hole | ExprKind::Lit(_) | ExprKind::Var(_, _) => {} + } +} + +/// Recursively copies a local item (callable, namespace, or UDT) and every +/// body node it references so nested items referenced via `StmtKind::Item` +/// or `ExprKind::Closure` remain resolvable. +fn extract_item(source: &Package, item_id: LocalItemId, target: &mut Package) { + if target.items.contains_key(item_id) { + return; + } + let item = source.get_item(item_id); + target.items.insert(item_id, item.clone()); + if let ItemKind::Callable(decl) = &item.kind { + // Extract all nodes transitively referenced by this callable into + // the target body package. + extract_pat(source, decl.input, target); + match &decl.implementation { + CallableImpl::Intrinsic => {} + CallableImpl::Spec(spec_impl) => { + extract_spec_decl_body(source, &spec_impl.body, target); + for spec in functored_specs(spec_impl) { + extract_spec_decl_body(source, spec, target); + } + } + CallableImpl::SimulatableIntrinsic(spec) => { + extract_spec_decl_body(source, spec, target); + } + } + } +} + +/// Recursively copies a pattern and its sub-patterns (for tuple patterns). +fn extract_pat(source: &Package, pat_id: PatId, target: &mut Package) { + if target.pats.contains_key(pat_id) { + return; + } + let pat = source.get_pat(pat_id); + target.pats.insert(pat_id, pat.clone()); + if let PatKind::Tuple(sub_pats) = &pat.kind { + for &p in sub_pats { + extract_pat(source, p, target); + } + } +} + +/// Builds a `ParamId → GenericArg` map by pairing positional arguments with +/// their index as the parameter identifier. 
+fn build_arg_map(args: &[GenericArg]) -> FxHashMap { + args.iter() + .enumerate() + .map(|(ix, arg)| (ParamId::from(ix), arg.clone())) + .collect() +} + +/// Replaces every `Ty::Param` in `ty` with its mapped concrete type. +fn substitute_ty(ty: &Ty, arg_map: &FxHashMap) -> Ty { + match ty { + Ty::Param(param) => match arg_map.get(param) { + Some(GenericArg::Ty(concrete)) => concrete.clone(), + _ => ty.clone(), + }, + Ty::Array(inner) => Ty::Array(Box::new(substitute_ty(inner, arg_map))), + Ty::Arrow(arrow) => Ty::Arrow(Box::new(substitute_arrow(arrow, arg_map))), + Ty::Tuple(items) => Ty::Tuple(items.iter().map(|t| substitute_ty(t, arg_map)).collect()), + Ty::Prim(_) | Ty::Udt(_) | Ty::Infer(_) | Ty::Err => ty.clone(), + } +} + +/// Applies [`substitute_ty`] and [`substitute_functor_set`] to each field of +/// an `Arrow` type. +fn substitute_arrow(arrow: &Arrow, arg_map: &FxHashMap) -> Arrow { + Arrow { + kind: arrow.kind, + input: Box::new(substitute_ty(&arrow.input, arg_map)), + output: Box::new(substitute_ty(&arrow.output, arg_map)), + functors: substitute_functor_set(arrow.functors, arg_map), + } +} + +/// Replaces a `FunctorSet::Param` with its mapped concrete functor set. +fn substitute_functor_set( + functors: FunctorSet, + arg_map: &FxHashMap, +) -> FunctorSet { + match functors { + FunctorSet::Param(param) => match arg_map.get(¶m) { + Some(GenericArg::Functor(concrete)) => *concrete, + _ => functors, + }, + _ => functors, + } +} + +/// Walks all nodes that the cloner inserted into the target package and +/// replaces `Ty::Param` / `FunctorSet::Param` with concrete types. +/// Also substitutes types inside generic args on `ExprKind::Var` expressions +/// and clears generic args that become concrete after substitution. 
+/// +/// # Before +/// ```text +/// Expr { ty: Ty::Param(0), kind: Var(item, [Ty(Param(0))]) } +/// Block { ty: Ty::Param(0) } +/// Pat { ty: Ty::Param(0) } +/// ``` +/// # After +/// ```text +/// Expr { ty: Int, kind: Var(item, [Ty(Int)]) } // Param(0) → Int +/// Block { ty: Int } +/// Pat { ty: Int } +/// ``` +/// +/// # Mutations +/// - Rewrites `Expr.ty`, `Block.ty`, and `Pat.ty` for every cloned node. +/// - Substitutes generic args on `ExprKind::Var` expressions. +/// - Substitutes callable declaration output types for nested items. +fn substitute_types_in_cloned_nodes( + target: &mut Package, + cloner: &FirCloner, + arg_map: &FxHashMap, +) { + // Blocks. + for &new_id in cloner.block_map().values() { + if let Some(block) = target.blocks.get_mut(new_id) { + block.ty = substitute_ty(&block.ty, arg_map); + } + } + + // Expressions — substitute types and handle generic args on Var. + for &new_id in cloner.expr_map().values() { + if let Some(expr) = target.exprs.get_mut(new_id) { + expr.ty = substitute_ty(&expr.ty, arg_map); + + // Substitute types within generic args on Var. + if let ExprKind::Var(_, ref mut generic_args) = expr.kind + && !generic_args.is_empty() + { + for ga in generic_args.iter_mut() { + *ga = substitute_generic_arg(ga, arg_map); + } + // Do NOT clear here — rewrite_call_sites needs the + // substituted args to find the monomorphized target. + } + } + } + + // Patterns. + for &new_id in cloner.pat_map().values() { + if let Some(pat) = target.pats.get_mut(new_id) { + pat.ty = substitute_ty(&pat.ty, arg_map); + } + } + + // Nested callable items cloned into a specialization may capture outer + // generic parameters in their signatures even when they do not declare + // generics of their own (for example, lifted lambdas inside generic + // stdlib helpers). Rewrite those declaration-level types as well. 
+ for &new_id in cloner.item_map().values() { + let Some(item) = target.items.get_mut(new_id) else { + continue; + }; + let ItemKind::Callable(decl) = &mut item.kind else { + continue; + }; + if decl.generics.is_empty() { + decl.output = substitute_ty(&decl.output, arg_map); + } + } +} + +/// Substitutes type parameters inside a `GenericArg` (delegating to +/// [`substitute_ty`] or [`substitute_functor_set`]). +fn substitute_generic_arg(ga: &GenericArg, arg_map: &FxHashMap) -> GenericArg { + match ga { + GenericArg::Ty(ty) => GenericArg::Ty(substitute_ty(ty, arg_map)), + GenericArg::Functor(fs) => GenericArg::Functor(substitute_functor_set(*fs, arg_map)), + } +} + +/// Rewrites every generic `Var` call site in the target package to point at +/// the monomorphized callable produced by [`create_specializations`]. +/// +/// # Before +/// ```text +/// Var(Item(generic_callable), [Ty(Int), Functor(Adj)]) +/// ``` +/// # After +/// ```text +/// Var(Item(monomorphized_callable), []) // generic args cleared +/// ``` +/// +/// Residual non-empty generic argument lists on sites whose target has no +/// matching specialization (e.g. intrinsics) are cleared so no `Ty::Param` +/// survives the pass. +/// +/// # Mutations +/// - Rewrites `ExprKind::Var` nodes to reference monomorphized items and +/// clears their generic-argument lists. +fn rewrite_call_sites( + package: &mut Package, + package_id: PackageId, + specializations: &[Specialization], + expr_ids: &[ExprId], +) { + // Build a lookup from (source key) → new ItemId. + let lookup: FxHashMap = specializations + .iter() + .map(|s| (mono_key(s.source, &s.args), s.new_item_id)) + .collect(); + + // Walk scoped expressions and rewrite generic Var references. 
+ for &expr_id in expr_ids { + let expr = package.exprs.get(expr_id).expect("expr should exist"); + if let ExprKind::Var(Res::Item(item_id), ref generic_args) = expr.kind { + if generic_args.is_empty() { + continue; + } + let store_id = StoreItemId::from((item_id.package, item_id.item)); + let key = mono_key(store_id, generic_args); + if let Some(&new_id) = lookup.get(&key) { + let expr_mut = package.exprs.get_mut(expr_id).expect("expr should exist"); + expr_mut.kind = ExprKind::Var(Res::Item(new_id), vec![]); + } else { + // No specialization found — still clear the generic args since + // the types have been substituted already (e.g., intrinsics that + // don't need cloning but whose type params were resolved). + + // Check if this is expected (intrinsic) or a potential bug. + // Only flag when all generic args are concrete — call sites + // inside uninstantiated generic bodies still carry Ty::Param + // references, and those are expected to remain unresolved. + let all_concrete = is_fully_concrete(generic_args); + if all_concrete + && item_id.package == package_id + && let Some(item) = package.items.get(item_id.item) + && let ItemKind::Callable(decl) = &item.kind + { + // Only flag if the target callable actually declares + // type parameters. Call sites pointing at a specialization + // carry an empty generic-arg list; any residual non-empty + // list on a non-specialized target (e.g. an intrinsic) is + // cleared here. + if !decl.generics.is_empty() + && !matches!(decl.implementation, CallableImpl::Intrinsic) + { + panic!( + "Non-intrinsic same-package callable has no monomorphized specialization: \ + item={item_id:?}, args={generic_args:?}" + ); + } + } + + let expr_mut = package.exprs.get_mut(expr_id).expect("expr should exist"); + if let ExprKind::Var(_, ref mut args) = expr_mut.kind { + args.clear(); + } + } + } + } + + // No separate entry-expression rewrite is needed here. 
The package entry + // is stored as an ExprId in `package.exprs`, whether it came from an + // explicit entry expression or a synthesized `Main()` call. +} diff --git a/source/compiler/qsc_fir_transforms/src/monomorphize/semantic_equivalence_tests.rs b/source/compiler/qsc_fir_transforms/src/monomorphize/semantic_equivalence_tests.rs new file mode 100644 index 0000000000..4c46ddfbcc --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/monomorphize/semantic_equivalence_tests.rs @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use indoc::formatdoc; +use proptest::prelude::*; + +/// Generates syntactically valid Q# programs exercising monomorphization's +/// key code paths: single and multiple type parameters, nested generic calls, +/// and multiple instantiations of the same generic. +fn mono_pattern_strategy() -> impl Strategy { + let val = || 0..50i64; + + prop_oneof![ + // 1. Single type parameter instantiated with Int. + val().prop_map(|a| formatdoc! {" + namespace Test {{ + function Identity<'T>(x : 'T) : 'T {{ x }} + function Main() : Int {{ + Identity({a}) + }} + }} + "}), + // 2. Single type parameter instantiated with Bool. + val().prop_map(|a| formatdoc! {" + namespace Test {{ + function Identity<'T>(x : 'T) : 'T {{ x }} + function IsPositive(n : Int) : Bool {{ n > 0 }} + function Main() : Bool {{ + Identity(IsPositive({a})) + }} + }} + "}), + // 3. Multiple instantiations of the same generic in one program. + (val(), val()).prop_map(|(a, b)| formatdoc! {" + namespace Test {{ + function Identity<'T>(x : 'T) : 'T {{ x }} + function Main() : Int {{ + let x = Identity({a}); + let y = Identity(true); + let z = Identity({b}); + x + z + }} + }} + "}), + // 4. Nested generic calls: generic calling generic. + val().prop_map(|a| formatdoc! 
{" + namespace Test {{ + function Identity<'T>(x : 'T) : 'T {{ x }} + function Wrap<'T>(x : 'T) : 'T {{ Identity(x) }} + function Main() : Int {{ + Wrap({a}) + }} + }} + "}), + ] +} + +proptest! { + #![proptest_config(ProptestConfig::with_cases(50))] + #[test] + fn differential_monomorphize(source in mono_pattern_strategy()) { + crate::test_utils::check_semantic_equivalence(&source); + } +} diff --git a/source/compiler/qsc_fir_transforms/src/monomorphize/tests.rs b/source/compiler/qsc_fir_transforms/src/monomorphize/tests.rs new file mode 100644 index 0000000000..d6cc60601d --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/monomorphize/tests.rs @@ -0,0 +1,1214 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use super::*; +use expect_test::{Expect, expect}; +use indoc::indoc; +use qsc_fir::assigner::Assigner; +use qsc_fir::fir::NodeId; +use rustc_hash::FxHashSet; + +/// Compiles Q# source, runs monomorphization, and snapshots all callables +/// in the user package showing name, generic-param count, input type, and +/// output type. Sorted for determinism. 
+fn check(source: &str, expect: &Expect) { + let (mut store, pkg_id) = crate::test_utils::compile_to_fir(source); + let mut assigner = Assigner::from_package(store.get(pkg_id)); + monomorphize(&mut store, pkg_id, &mut assigner); + + let package = store.get(pkg_id); + let mut lines: Vec = Vec::new(); + for (_, item) in &package.items { + if let ItemKind::Callable(decl) = &item.kind { + let pat = package.get_pat(decl.input); + lines.push(format!( + "{}: generics={}, input={}, output={}", + decl.name.name, + decl.generics.len(), + pat.ty, + decl.output, + )); + } + } + lines.sort(); + expect.assert_eq(&lines.join("\n")); +} + +fn check_details(source: &str, expect: &Expect) { + let (store, pkg_id) = crate::test_utils::compile_and_run_pipeline_to( + source, + crate::test_utils::PipelineStage::Mono, + ); + expect.assert_eq(&crate::test_utils::extract_reachable_callable_details( + &store, pkg_id, + )); +} + +/// Compiles Q# source, runs monomorphization, and asserts no +/// `ExprKind::Var` in the user package still carries generic args. +fn assert_no_generic_args(source: &str) { + let (mut store, pkg_id) = crate::test_utils::compile_to_fir(source); + let mut assigner = Assigner::from_package(store.get(pkg_id)); + monomorphize(&mut store, pkg_id, &mut assigner); + + let package = store.get(pkg_id); + for (id, expr) in &package.exprs { + if let ExprKind::Var(_, ref args) = expr.kind { + assert!( + args.is_empty(), + "Expr {id} still has non-empty generic args after monomorphization" + ); + } + } +} + +#[test] +fn mono_explicit_entry_expression_rewritten() { + let (mut store, pkg_id) = crate::test_utils::compile_to_fir_with_entry( + indoc! 
{r#" + namespace Test { + function Identity<'T>(x : 'T) : 'T { x } + } + "#}, + "Test.Identity(42)", + ); + let mut assigner = Assigner::from_package(store.get(pkg_id)); + monomorphize(&mut store, pkg_id, &mut assigner); + + let package = store.get(pkg_id); + let entry_id = package + .entry + .expect("package should have an entry expression"); + let entry_expr = package.get_expr(entry_id); + let ExprKind::Call(callee_id, _) = entry_expr.kind else { + panic!("entry expression should remain a call") + }; + let callee_expr = package.get_expr(callee_id); + let ExprKind::Var(Res::Item(item_id), ref generic_args) = callee_expr.kind else { + panic!("entry callee should be a callable reference") + }; + + assert!( + generic_args.is_empty(), + "entry-expression callee should not retain generic args after monomorphization" + ); + + let item = package.get_item(item_id.item); + let ItemKind::Callable(decl) = &item.kind else { + panic!("entry callee should resolve to a callable item") + }; + assert_eq!(decl.name.name.as_ref(), "Identity"); +} + +#[test] +fn mono_identity_int() { + check( + indoc! {r#" + operation Identity<'T>(x : 'T) : 'T { x } + operation Main() : Int { Identity(42) } + "#}, + &expect![[r#" + Identity: generics=1, input=Param<0>, output=Param<0> + Identity: generics=0, input=Int, output=Int + Main: generics=0, input=Unit, output=Int"#]], + ); +} + +#[test] +fn mono_identity_qubit() { + check( + indoc! {r#" + operation Identity<'T>(x : 'T) : 'T { x } + operation Main() : Unit { + use q = Qubit(); + let _ = Identity(q); + } + "#}, + &expect![[r#" + Identity: generics=1, input=Param<0>, output=Param<0> + Identity: generics=0, input=Qubit, output=Qubit + Main: generics=0, input=Unit, output=Unit"#]], + ); +} + +#[test] +fn mono_two_instantiations() { + check( + indoc! 
{r#" + operation Identity<'T>(x : 'T) : 'T { x } + operation Main() : Unit { + let _ = Identity(42); + use q = Qubit(); + let _ = Identity(q); + } + "#}, + &expect![[r#" + Identity: generics=1, input=Param<0>, output=Param<0> + Identity: generics=0, input=Int, output=Int + Identity: generics=0, input=Qubit, output=Qubit + Main: generics=0, input=Unit, output=Unit"#]], + ); +} + +#[test] +fn mono_no_generic_args() { + check( + "operation Main() : Int { 42 }", + &expect!["Main: generics=0, input=Unit, output=Int"], + ); +} + +#[test] +fn mono_multiple_call_sites_same_args() { + // Two call sites with Identity should produce only one + // specialization. + check( + indoc! {r#" + operation Identity<'T>(x : 'T) : 'T { x } + operation Main() : Unit { + let _ = Identity(1); + let _ = Identity(2); + } + "#}, + &expect![[r#" + Identity: generics=1, input=Param<0>, output=Param<0> + Identity: generics=0, input=Int, output=Int + Main: generics=0, input=Unit, output=Unit"#]], + ); +} + +#[test] +fn mono_generic_args_cleared_after_mono() { + assert_no_generic_args(indoc! {r#" + operation Identity<'T>(x : 'T) : 'T { x } + operation Main() : Unit { + let _ = Identity(42); + use q = Qubit(); + let _ = Identity(q); + } + "#}); +} + +#[test] +fn mono_nested_generic_call() { + // Outer<'T> calls Identity<'T> — both should be specialized. + check( + indoc! {r#" + operation Identity<'T>(x : 'T) : 'T { x } + operation Outer<'T>(x : 'T) : 'T { Identity(x) } + operation Main() : Int { Outer(42) } + "#}, + &expect![[r#" + Identity: generics=1, input=Param<0>, output=Param<0> + Identity: generics=0, input=Int, output=Int + Main: generics=0, input=Unit, output=Int + Outer: generics=1, input=Param<0>, output=Param<0> + Outer: generics=0, input=Int, output=Int"#]], + ); +} + +#[test] +fn mono_nested_generic_body_retargets_specialized_callee() { + check_details( + indoc! 
{r#" + function Inner<'T>(x : 'T) : 'T { x } + function Outer<'T>(x : 'T) : 'T { + let first = Inner(x); + Inner(first) + } + function Main() : Int { Outer(42) } + "#}, + &expect![[r#" + callable Inner: input_ty=Int, output_ty=Int + body: block_ty=Int + [0] Expr ty=Int Var + callable Main: input_ty=Unit, output_ty=Int + body: block_ty=Int + [0] Expr ty=Int Call(Outer, arg_ty=Int) + callable Outer: input_ty=Int, output_ty=Int + body: block_ty=Int + [0] Local pat_ty=Int init_ty=Int Call(Inner, arg_ty=Int) + [1] Expr ty=Int Call(Inner, arg_ty=Int)"#]], + ); +} + +#[test] +fn mono_partial_application_skips_non_concrete_stdlib_generics() { + let source = indoc! {r#" + namespace Test { + import Std.Arrays.*; + import Std.Convert.*; + import Std.Diagnostics.*; + import Std.Intrinsic.*; + import Std.Math.*; + import Std.Measurement.*; + + @EntryPoint() + operation Main() : Result[] { + let secretBitString = SecretBitStringAsBoolArray(); + let parityOperation = EncodeBitStringAsParityOperation(secretBitString); + let decodedBitString = BernsteinVazirani( + parityOperation, + Length(secretBitString) + ); + + return decodedBitString; + } + + operation BernsteinVazirani(Uf : ((Qubit[], Qubit) => Unit), n : Int) : Result[] { + use queryRegister = Qubit[n]; + use target = Qubit(); + X(target); + within { + ApplyToEachA(H, queryRegister); + } apply { + H(target); + Uf(queryRegister, target); + } + let resultArray = MResetEachZ(queryRegister); + Reset(target); + return resultArray; + } + + operation ApplyParityOperation( + bitStringAsBoolArray : Bool[], + xRegister : Qubit[], + yQubit : Qubit + ) : Unit { + let requiredBits = Length(bitStringAsBoolArray); + let availableQubits = Length(xRegister); + Fact( + availableQubits >= requiredBits, + $"The bitstring has {requiredBits} bits but the quantum register " + $"only has {availableQubits} qubits" + ); + for (index, bit) in Enumerated(bitStringAsBoolArray) { + if bit { + CNOT(xRegister[index], yQubit); + } + } + } + + operation 
EncodeBitStringAsParityOperation(bitStringAsBoolArray : Bool[]) : (Qubit[], Qubit) => Unit {
+                return ApplyParityOperation(bitStringAsBoolArray, _, _);
+            }
+
+            function SecretBitStringAsBoolArray() : Bool[] {
+                return [true, false, true, false, true];
+            }
+        }
+    "#};
+
+    let (mut store, pkg_id) = crate::test_utils::compile_to_fir(source);
+    let mut assigner = Assigner::from_package(store.get(pkg_id));
+    monomorphize(&mut store, pkg_id, &mut assigner);
+    let reachable = crate::reachability::collect_reachable_from_entry(&store, pkg_id);
+    let package = store.get(pkg_id);
+    let offenders = package
+        .items
+        .iter()
+        .filter(|(item_id, _)| {
+            reachable.contains(&qsc_fir::fir::StoreItemId {
+                package: pkg_id,
+                item: *item_id,
+            })
+        })
+        .filter_map(|(_, item)| {
+            let ItemKind::Callable(decl) = &item.kind else {
+                return None;
+            };
+
+            let input_ty = &package.get_pat(decl.input).ty;
+            let output_has_param = super::ty_contains_param(&decl.output);
+            let input_has_param = super::ty_contains_param(input_ty);
+            let functor_param = matches!(input_ty, qsc_fir::ty::Ty::Arrow(arrow) if matches!(arrow.functors, qsc_fir::ty::FunctorSet::Param(_)));
+
+            (output_has_param || input_has_param || functor_param).then(|| {
+                format!(
+                    "{}: generics={}, input={}, output={}",
+                    decl.name.name,
+                    decl.generics.len(),
+                    input_ty,
+                    decl.output,
+                )
+            })
+        })
+        .collect::<Vec<_>>();
+    assert!(
+        offenders.is_empty(),
+        "offending callables after mono:\n{}",
+        offenders.join("\n")
+    );
+    crate::invariants::check(&store, pkg_id, crate::invariants::InvariantLevel::PostMono);
+}
+
+#[test]
+fn mono_nested_depth_2() {
+    // A→B→C chain of generic calls.
+    check(
+        indoc!
{r#" + operation C<'T>(x : 'T) : 'T { x } + operation B<'T>(x : 'T) : 'T { C(x) } + operation A<'T>(x : 'T) : 'T { B(x) } + operation Main() : Int { A(42) } + "#}, + &expect![[r#" + A: generics=1, input=Param<0>, output=Param<0> + A: generics=0, input=Int, output=Int + B: generics=1, input=Param<0>, output=Param<0> + B: generics=0, input=Int, output=Int + C: generics=1, input=Param<0>, output=Param<0> + C: generics=0, input=Int, output=Int + Main: generics=0, input=Unit, output=Int"#]], + ); +} + +#[test] +fn mono_nested_diamond() { + // Diamond: A calls B and C, both call D. + // D should be specialized only once. + check( + indoc! {r#" + operation D<'T>(x : 'T) : 'T { x } + operation B<'T>(x : 'T) : 'T { D(x) } + operation C<'T>(x : 'T) : 'T { D(x) } + operation A<'T>(x : 'T) : 'T { + let _ = B(x); + C(x) + } + operation Main() : Int { A(42) } + "#}, + &expect![[r#" + A: generics=1, input=Param<0>, output=Param<0> + A: generics=0, input=Int, output=Int + B: generics=1, input=Param<0>, output=Param<0> + B: generics=0, input=Int, output=Int + C: generics=1, input=Param<0>, output=Param<0> + C: generics=0, input=Int, output=Int + D: generics=1, input=Param<0>, output=Param<0> + D: generics=0, input=Int, output=Int + Main: generics=0, input=Unit, output=Int"#]], + ); +} + +#[test] +fn mono_arrow_param() { + // Generic callable with arrow-typed parameter. + check( + indoc! {r#" + operation ApplyOp<'T>(f : 'T => 'T, x : 'T) : 'T { f(x) } + operation DoubleInt(x : Int) : Int { x * 2 } + operation Main() : Int { ApplyOp(DoubleInt, 5) } + "#}, + &expect![[r#" + ApplyOp: generics=2, input=((Param<0> => Param<0> is 1), Param<0>), output=Param<0> + ApplyOp: generics=0, input=((Int => Int), Int), output=Int + DoubleInt: generics=0, input=Int, output=Int + Main: generics=0, input=Unit, output=Int"#]], + ); +} + +#[test] +fn mono_generic_with_body_locals() { + check( + indoc! 
{r#" + operation Transform<'T>(x : 'T) : 'T { + let tmp = x; + tmp + } + operation Main() : Int { Transform(42) } + "#}, + &expect![[r#" + Main: generics=0, input=Unit, output=Int + Transform: generics=1, input=Param<0>, output=Param<0> + Transform: generics=0, input=Int, output=Int"#]], + ); +} + +#[test] +fn mono_generic_preserves_local_chain() { + // Multiple local bindings chained together. + check( + indoc! {r#" + operation Chain<'T>(x : 'T) : 'T { + let a = x; + let b = a; + let c = b; + let d = c; + d + } + operation Main() : Int { Chain(42) } + "#}, + &expect![[r#" + Chain: generics=1, input=Param<0>, output=Param<0> + Chain: generics=0, input=Int, output=Int + Main: generics=0, input=Unit, output=Int"#]], + ); +} + +#[test] +fn mono_generic_with_ctl_spec() { + check( + indoc! {r#" + operation ApplyCtl<'T>(x : 'T) : Unit is Ctl { + body ... { } + controlled (ctls, ...) { } + } + operation Main() : Unit { + use q = Qubit(); + ApplyCtl(42); + } + "#}, + &expect![[r#" + ApplyCtl: generics=1, input=Param<0>, output=Unit + ApplyCtl: generics=0, input=Int, output=Unit + Main: generics=0, input=Unit, output=Unit"#]], + ); +} + +#[test] +fn mono_closure_in_generic() { + check( + indoc! {r#" + operation WithClosure<'T>(x : 'T) : 'T { + let f = (y) -> y; + f(x) + } + operation Main() : Int { WithClosure(42) } + "#}, + &expect![[r#" + : generics=0, input=(Int,), output=Int + : generics=0, input=(Param<0>,), output=Param<0> + Main: generics=0, input=Unit, output=Int + WithClosure: generics=1, input=Param<0>, output=Param<0> + WithClosure: generics=0, input=Int, output=Int"#]], + ); +} + +#[test] +fn mono_cross_package_length() { + // Length is a cross-package intrinsic generic callable in std. + check( + indoc! 
{r#" + operation Main() : Int { + let arr = [1, 2, 3]; + Length(arr) + } + "#}, + &expect![[r#" + Length: generics=0, input=(Int)[], output=Int + Main: generics=0, input=Unit, output=Int"#]], + ); +} + +#[test] +fn mono_cross_package_reversed() { + // Reversed is a cross-package generic callable. + check( + indoc! {r#" + operation Main() : Int[] { + let arr = [1, 2, 3]; + Microsoft.Quantum.Arrays.Reversed(arr) + } + "#}, + &expect![[r#" + Main: generics=0, input=Unit, output=(Int)[] + Reversed: generics=0, input=(Int)[], output=(Int)[]"#]], + ); +} + +#[test] +fn mono_cross_package_with_same_name() { + // Generic function uses same name as a cross-package generic callable. + check( + indoc! {r#" + function Reversed<'T>(array : 'T[]) : 'T[] { + Microsoft.Quantum.Arrays.Reversed(array) + } + operation Main() : Int[] { + let arr = [1, 2, 3]; + Reversed(arr) + } + "#}, + &expect![[r#" + Main: generics=0, input=Unit, output=(Int)[] + Reversed: generics=1, input=(Param<0>)[], output=(Param<0>)[] + Reversed: generics=0, input=(Int)[], output=(Int)[] + Reversed: generics=0, input=(Int)[], output=(Int)[]"#]], + ); +} + +#[test] +fn mono_identity_instantiation_not_duplicated() { + // When Outer<'T> calls Inner<'T>, the Inner reference is + // an identity instantiation. Only concrete instantiations (from the + // entry) should produce specializations. + check( + indoc! {r#" + operation Inner<'T>(x : 'T) : 'T { x } + operation Outer<'T>(x : 'T) : 'T { Inner(x) } + operation Main() : Int { Outer(42) } + "#}, + &expect![[r#" + Inner: generics=1, input=Param<0>, output=Param<0> + Inner: generics=0, input=Int, output=Int + Main: generics=0, input=Unit, output=Int + Outer: generics=1, input=Param<0>, output=Param<0> + Outer: generics=0, input=Int, output=Int"#]], + ); +} + +#[test] +fn mono_two_type_params() { + check( + indoc! 
{r#" + operation Pair<'A, 'B>(a : 'A, b : 'B) : 'A { a } + operation Main() : Int { + use q = Qubit(); + Pair(42, q) + } + "#}, + &expect![[r#" + Main: generics=0, input=Unit, output=Int + Pair: generics=2, input=(Param<0>, Param<1>), output=Param<0> + Pair: generics=0, input=(Int, Qubit), output=Int"#]], + ); +} + +#[test] +fn mono_specialized_callable_node_ids_do_not_collide_with_spec_nodes() { + let (mut store, pkg_id) = crate::test_utils::compile_to_fir(indoc! {r#" + operation ApplyCtl<'T>(x : 'T) : Unit is Ctl { + body ... { } + controlled (ctls, ...) { } + } + operation Main() : Unit { + ApplyCtl(42); + } + "#}); + let mut assigner = Assigner::from_package(store.get(pkg_id)); + monomorphize(&mut store, pkg_id, &mut assigner); + + let package = store.get(pkg_id); + let mut seen = FxHashSet::default(); + for item in package.items.values() { + let ItemKind::Callable(decl) = &item.kind else { + continue; + }; + assert_node_id_is_unique(decl.id, &mut seen); + match &decl.implementation { + CallableImpl::Spec(spec_impl) => { + assert_node_id_is_unique(spec_impl.body.id, &mut seen); + for spec in crate::fir_builder::functored_specs(spec_impl) { + assert_node_id_is_unique(spec.id, &mut seen); + } + } + CallableImpl::SimulatableIntrinsic(spec) => { + assert_node_id_is_unique(spec.id, &mut seen); + } + CallableImpl::Intrinsic => {} + } + } +} + +#[test] +#[should_panic( + expected = "Non-intrinsic same-package callable has no monomorphized specialization" +)] +fn mono_missing_same_package_specialization_panics() { + let (mut store, pkg_id) = crate::test_utils::compile_to_fir(indoc! 
{r#"
+        function Identity<'T>(x : 'T) : 'T { x }
+        function Main() : Int { Identity(42) }
+    "#});
+
+    let expr_ids: Vec<_> = store.get(pkg_id).exprs.iter().map(|(id, _)| id).collect();
+    rewrite_call_sites(store.get_mut(pkg_id), pkg_id, &[], &expr_ids);
+}
+
+fn assert_node_id_is_unique(node_id: NodeId, seen: &mut FxHashSet<u32>) {
+    assert!(
+        seen.insert(u32::from(node_id)),
+        "NodeId {node_id:?} should be unique after monomorphization"
+    );
+}
+
+#[test]
+fn mono_recursive_generic() {
+    // Recursive generic callable — self-references should be rewritten
+    // to point at the specialized clone.
+    check(
+        indoc! {r#"
+            operation Repeat<'T>(x : 'T, n : Int) : 'T {
+                if n <= 0 {
+                    x
+                } else {
+                    Repeat(x, n - 1)
+                }
+            }
+            operation Main() : Int { Repeat(42, 3) }
+        "#},
+        &expect![[r#"
+            Main: generics=0, input=Unit, output=Int
+            Repeat: generics=1, input=(Param<0>, Int), output=Param<0>
+            Repeat: generics=0, input=(Int, Int), output=Int"#]],
+    );
+}
+
+#[test]
+fn mono_invariants_hold_post_pass() {
+    let (store, pkg_id) = crate::test_utils::compile_and_run_pipeline_to(
+        indoc! {r#"
+            operation Identity<'T>(x : 'T) : 'T { x }
+            operation Outer<'T>(x : 'T) : 'T { Identity(x) }
+            operation Main() : Int { Outer(42) }
+        "#},
+        crate::test_utils::PipelineStage::Mono,
+    );
+    // If we reach here, the invariant check inside
+    // compile_and_run_pipeline_to already passed.
+    let _ = (store, pkg_id);
+}
+
+#[test]
+fn mono_generic_with_simulatable_intrinsic() {
+    // A generic function used via a simulatable intrinsic path.
+    // Length is a cross-package intrinsic: verify it's specialized.
+    check(
+        indoc!
{r#" + operation Wrap<'T>(arr : 'T[]) : Int { Length(arr) } + operation Main() : Int { + Wrap([1, 2, 3]) + } + "#}, + &expect![[r#" + Length: generics=0, input=(Int)[], output=Int + Main: generics=0, input=Unit, output=Int + Wrap: generics=1, input=(Param<0>)[], output=Int + Wrap: generics=0, input=(Int)[], output=Int"#]], + ); +} + +#[test] +fn mono_generic_with_functor_param() { + // Generic callable with a functor-parameterized operation parameter. + check( + indoc! {r#" + operation RunOp<'T>(op : 'T => Unit, x : 'T) : Unit { op(x) } + operation NoOp(x : Int) : Unit {} + operation Main() : Unit { RunOp(NoOp, 42) } + "#}, + &expect![[r#" + Main: generics=0, input=Unit, output=Unit + NoOp: generics=0, input=Int, output=Unit + RunOp: generics=2, input=((Param<0> => Unit is 1), Param<0>), output=Unit + RunOp: generics=0, input=((Int => Unit), Int), output=Unit"#]], + ); +} + +#[test] +fn mono_functor_specialized_clone_preserves_explicit_specs() { + check_details( + indoc! {r#" + operation ApplyOp<'T>(op : 'T => Unit is Adj + Ctl, x : 'T) : Unit is Adj + Ctl { + body ... { op(x); } + adjoint ... { Adjoint op(x); } + controlled (ctls, ...) { Controlled op(ctls, x); } + controlled adjoint (ctls, ...) 
{ Controlled Adjoint op(ctls, x); } + } + operation Main() : Unit { + use q = Qubit(); + ApplyOp(S, q); + } + "#}, + &expect![[r#" + callable ApplyOp: input_ty=((Qubit => Unit is Adj + Ctl), Qubit), output_ty=Unit + body: block_ty=Unit + [0] Semi ty=Unit Call(Local(op), arg_ty=Qubit) + adj: block_ty=Unit + [0] Semi ty=Unit Call(Functor Adj(Local(op)), arg_ty=Qubit) + ctl: block_ty=Unit + [0] Semi ty=Unit Call(Functor Ctl(Local(op)), arg_ty=((Qubit)[], Qubit)) + ctl_adj: block_ty=Unit + [0] Semi ty=Unit Call(Functor Ctl(Functor Adj(Local(op))), arg_ty=((Qubit)[], Qubit)) + callable Main: input_ty=Unit, output_ty=Unit + body: block_ty=Unit + [0] Local pat_ty=Qubit init_ty=Qubit Call(Item(Item 8 (Package 0)), arg_ty=Unit) + [1] Semi ty=Unit Call(ApplyOp, arg_ty=((Qubit => Unit is Adj + Ctl), Qubit)) + [2] Semi ty=Unit Call(Item(Item 10 (Package 0)), arg_ty=Qubit)"#]], + ); +} + +#[test] +fn mono_generic_with_adj_ctl_specs_in_body() { + // Generic operation with adjoint + controlled specs. + check( + indoc! {r#" + operation DoIt<'T>(x : 'T) : Unit is Adj + Ctl { + body ... { } + adjoint self; + controlled (ctls, ...) { } + controlled adjoint self; + } + operation Main() : Unit { + DoIt(42); + } + "#}, + &expect![[r#" + DoIt: generics=1, input=Param<0>, output=Unit + DoIt: generics=0, input=Int, output=Unit + Main: generics=0, input=Unit, output=Unit"#]], + ); +} + +#[test] +fn mono_generic_captures_variable() { + // A closure inside a generic callable captures a variable typed with + // the generic parameter. + check( + indoc! 
{r#" + operation WithCapture<'T>(x : 'T) : 'T { + let captured = x; + let f = () -> captured; + f() + } + operation Main() : Int { WithCapture(42) } + "#}, + &expect![[r#" + : generics=0, input=(Int, Unit), output=Int + : generics=0, input=(Param<0>, Unit), output=Param<0> + Main: generics=0, input=Unit, output=Int + WithCapture: generics=1, input=Param<0>, output=Param<0> + WithCapture: generics=0, input=Int, output=Int"#]], + ); +} + +#[test] +fn mono_generic_array_of_type_param() { + // Generic callable taking an array of the type parameter. + check( + indoc! {r#" + operation First<'T>(arr : 'T[]) : 'T { arr[0] } + operation Main() : Int { First([10, 20, 30]) } + "#}, + &expect![[r#" + First: generics=1, input=(Param<0>)[], output=Param<0> + First: generics=0, input=(Int)[], output=Int + Main: generics=0, input=Unit, output=Int"#]], + ); +} + +#[test] +fn mono_generic_nested_tuple_types() { + // Generic callable returning a nested tuple containing the type param. + check( + indoc! {r#" + operation Nest<'T>(x : 'T) : (('T, Int), Bool) { ((x, 0), true) } + operation Main() : ((Int, Int), Bool) { Nest(42) } + "#}, + &expect![[r#" + Main: generics=0, input=Unit, output=((Int, Int), Bool) + Nest: generics=1, input=Param<0>, output=((Param<0>, Int), Bool) + Nest: generics=0, input=Int, output=((Int, Int), Bool)"#]], + ); +} + +#[test] +fn mono_mutual_recursion_different_types() { + // Two mutually recursive generic callables with the same type parameter. + check( + indoc! 
{r#" + operation Ping<'T>(x : 'T, n : Int) : 'T { + if n <= 0 { x } else { Pong(x, n - 1) } + } + operation Pong<'T>(x : 'T, n : Int) : 'T { + Ping(x, n) + } + operation Main() : Int { Ping(42, 2) } + "#}, + &expect![[r#" + Main: generics=0, input=Unit, output=Int + Ping: generics=1, input=(Param<0>, Int), output=Param<0> + Ping: generics=0, input=(Int, Int), output=Int + Pong: generics=1, input=(Param<0>, Int), output=Param<0> + Pong: generics=0, input=(Int, Int), output=Int"#]], + ); +} + +#[test] +fn mono_generic_with_adj_spec_only() { + // Generic operation with adjoint-only functor specification. + check( + indoc! {r#" + operation MyAdj<'T>(x : 'T) : Unit is Adj { + body ... { } + adjoint self; + } + operation Main() : Unit { + MyAdj(42); + Adjoint MyAdj(42); + } + "#}, + &expect![[r#" + Main: generics=0, input=Unit, output=Unit + MyAdj: generics=1, input=Param<0>, output=Unit + MyAdj: generics=0, input=Int, output=Unit"#]], + ); +} + +#[test] +fn mutual_recursion_between_generics_specializes_both() { + // Two mutually recursive generic functions: IsEven<'T> calls IsOdd<'T> + // and vice versa. Both should be specialized for Int. + let source = indoc! {r#" + function IsEven<'T>(n : Int, val : 'T) : Bool { + if n == 0 { true } else { IsOdd(n - 1, val) } + } + + function IsOdd<'T>(n : Int, val : 'T) : Bool { + if n == 0 { false } else { IsEven(n - 1, val) } + } + + function Main() : Bool { + IsEven(4, 0) + } + "#}; + check( + source, + &expect![[r#" + IsEven: generics=1, input=(Int, Param<0>), output=Bool + IsEven: generics=0, input=(Int, Int), output=Bool + IsOdd: generics=1, input=(Int, Param<0>), output=Bool + IsOdd: generics=0, input=(Int, Int), output=Bool + Main: generics=0, input=Unit, output=Bool"#]], + ); + // Verify PostMono invariants hold (no Ty::Param remaining). 
+ let _ = crate::test_utils::compile_and_run_pipeline_to( + source, + crate::test_utils::PipelineStage::Mono, + ); +} + +#[test] +fn deeply_nested_generic_args_specialize_correctly() { + // Generic callable instantiated with a complex nested type arg: + // (Int, Double) as the type parameter. + check( + indoc! {r#" + function Wrap<'T>(val : 'T) : 'T[] { + [val] + } + + function Main() : (Int, Double)[] { + Wrap((1, 2.0)) + } + "#}, + &expect![[r#" + Main: generics=0, input=Unit, output=((Int, Double))[] + Wrap: generics=1, input=Param<0>, output=(Param<0>)[] + Wrap<(Int, Double)>: generics=0, input=(Int, Double), output=((Int, Double))[]"#]], + ); +} + +#[test] +fn cross_package_non_intrinsic_generic_specializes() { + // Enumerated is a non-intrinsic cross-package generic that returns + // (Int, 'TElement)[] — structurally different output type from + // Reversed, and internally chains through MappedByIndex. + check( + indoc! {r#" + function Main() : (Int, Int)[] { + Microsoft.Quantum.Arrays.Enumerated([10, 20, 30]) + } + "#}, + &expect![[r#" + : generics=0, input=((Int, Int),), output=(Int, Int) + Enumerated: generics=0, input=(Int)[], output=((Int, Int))[] + Length: generics=0, input=(Int)[], output=Int + Main: generics=0, input=Unit, output=((Int, Int))[] + MappedByIndex: generics=0, input=(((Int, Int) -> (Int, Int)), (Int)[]), output=((Int, Int))[]"#]], + ); +} + +#[test] +#[should_panic(expected = "monomorphize requires a package entry expression")] +fn monomorphize_no_entry_panics() { + // Compile as a library (no @EntryPoint) so package.entry is None. + // monomorphize should panic because it requires an entry expression. 
+ use qsc_data_structures::{ + language_features::LanguageFeatures, source::SourceMap, target::TargetCapabilityFlags, + }; + use qsc_frontend::compile as frontend_compile; + use qsc_hir::hir::PackageId as HirPackageId; + use qsc_passes::{PackageType, lower_hir_to_fir, run_core_passes, run_default_passes}; + + let mut core_unit = frontend_compile::core(); + let core_errors = run_core_passes(&mut core_unit); + assert!(core_errors.is_empty()); + let mut hir_store = frontend_compile::PackageStore::new(core_unit); + + let mut std_unit = frontend_compile::std(&hir_store, TargetCapabilityFlags::empty()); + let std_errors = run_default_passes(hir_store.core(), &mut std_unit, PackageType::Lib); + assert!(std_errors.is_empty()); + hir_store.insert(std_unit); + + let std_id = HirPackageId::CORE.successor(); + let sources = SourceMap::new( + vec![( + "lib.qs".into(), + "function Helper<'T>(x : 'T) : 'T { x }".into(), + )], + None, + ); + let mut unit = frontend_compile::compile( + &hir_store, + &[(HirPackageId::CORE, None), (std_id, None)], + sources, + TargetCapabilityFlags::empty(), + LanguageFeatures::default(), + ); + crate::test_utils::assert_no_compile_errors("user code", &unit.errors); + let pass_errors = run_default_passes(hir_store.core(), &mut unit, PackageType::Lib); + assert!(pass_errors.is_empty()); + let hir_pkg_id = hir_store.insert(unit); + let (mut fir_store, fir_pkg_id, _) = lower_hir_to_fir(&hir_store, hir_pkg_id); + + assert!(fir_store.get(fir_pkg_id).entry.is_none()); + + let mut assigner = Assigner::from_package(fir_store.get(fir_pkg_id)); + monomorphize(&mut fir_store, fir_pkg_id, &mut assigner); +} + +#[test] +fn mono_preserves_simulatable_intrinsic_impl() { + // A generic @SimulatableIntrinsic callable should, after monomorphization, + // produce a specialization that retains the SimulatableIntrinsic variant. + let (mut store, pkg_id) = crate::test_utils::compile_to_fir(indoc! 
{r#" + @SimulatableIntrinsic() + operation MySimIntrinsic<'T>(x : 'T) : 'T { x } + operation Main() : Int { MySimIntrinsic(42) } + "#}); + let mut assigner = Assigner::from_package(store.get(pkg_id)); + monomorphize(&mut store, pkg_id, &mut assigner); + + let package = store.get(pkg_id); + let mut found_specialized = false; + for (_, item) in &package.items { + if let ItemKind::Callable(decl) = &item.kind + && decl.name.name.as_ref() == "MySimIntrinsic" + { + assert!( + matches!(decl.implementation, CallableImpl::SimulatableIntrinsic(_)), + "specialized callable should preserve SimulatableIntrinsic variant" + ); + assert!( + decl.generics.is_empty(), + "specialized callable should have no generic params" + ); + found_specialized = true; + } + } + assert!( + found_specialized, + "should find a specialized MySimIntrinsic callable" + ); +} + +#[test] +fn monomorphize_is_idempotent() { + let source = indoc! {r#" + operation Identity<'T>(x : 'T) : 'T { x } + operation Main() : Int { Identity(42) } + "#}; + let (mut store, pkg_id) = crate::test_utils::compile_and_run_pipeline_to( + source, + crate::test_utils::PipelineStage::Mono, + ); + let first = crate::pretty::write_package_qsharp(&store, pkg_id); + let mut assigner = Assigner::from_package(store.get(pkg_id)); + monomorphize(&mut store, pkg_id, &mut assigner); + let second = crate::pretty::write_package_qsharp(&store, pkg_id); + assert_eq!(first, second, "monomorphize should be idempotent"); +} + +fn render_before_after_mono(source: &str) -> (String, String) { + let (mut store, pkg_id) = crate::test_utils::compile_to_fir(source); + let before = crate::pretty::write_package_qsharp(&store, pkg_id); + let mut assigner = Assigner::from_package(store.get(pkg_id)); + monomorphize(&mut store, pkg_id, &mut assigner); + let after = crate::pretty::write_package_qsharp(&store, pkg_id); + (before, after) +} + +fn check_before_after(source: &str, expect: &Expect) { + let (before, after) = render_before_after_mono(source); + 
expect.assert_eq(&format!("BEFORE:\n{before}\nAFTER:\n{after}")); +} + +#[test] +fn before_after_generic_specialization() { + check_before_after( + indoc! {r#" + operation Identity<'T>(x : 'T) : 'T { x } + operation Main() : Int { Identity(42) } + "#}, + &expect![[r#" + BEFORE: + // namespace test + operation Identity<''T > (x : 'T0) : 'T0 { + body { + x + } + } + operation Main() : Int { + body { + Identity < Int > (42) + } + } + // entry + Main() + + AFTER: + // namespace test + operation Identity<''T > (x : 'T0) : 'T0 { + body { + x + } + } + operation Main() : Int { + body { + Identity < Int > (42) + } + } + operation Identity(x : Int) : Int { + body { + x + } + } + // entry + Main() + "#]], // snapshot populated by UPDATE_EXPECT=1 + ); +} + +#[test] +fn shared_input_and_arrow_generic_param_specializes() { + check_before_after( + indoc! {r#" + function double<'T: Add>(x : 'T) : 'T { x + x } + function doDouble<'T>(a : 'T, doubler : ('T -> 'T)) : 'T { doubler(a) } + operation Main() : Unit { + use q = Qubit(); + if M(q) == One { + doDouble(3, double); + } else { + doDouble(3.0, double); + } + } + "#}, + &expect![[r#" + BEFORE: + // namespace test + function double<''T > (x : 'T0) : 'T0 { + body { + x + x + } + } + function doDouble<''T > (a : 'T0, doubler : ('T0 -> 'T0)) : 'T0 { + body { + doubler(a) + } + } + operation Main() : Unit { + body { + let q : Qubit = __quantum__rt__qubit_allocate(); + let + @generated_ident_64 : Unit = if M(q) == One { + doDouble < Int > (3, double < Int >); + } else { + doDouble < Double > (3., double < Double >); + }; + __quantum__rt__qubit_release(q); + @generated_ident_64 + } + } + // entry + Main() + + AFTER: + // namespace test + function double<''T > (x : 'T0) : 'T0 { + body { + x + x + } + } + function doDouble<''T > (a : 'T0, doubler : ('T0 -> 'T0)) : 'T0 { + body { + doubler(a) + } + } + operation Main() : Unit { + body { + let q : Qubit = __quantum__rt__qubit_allocate(); + let + @generated_ident_64 : Unit = if M(q) == One 
{ + doDouble < Int > (3, double < Int >); + } else { + doDouble < Double > (3., double < Double >); + }; + __quantum__rt__qubit_release(q); + @generated_ident_64 + } + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function doDouble(a : Int, doubler : (Int -> Int)) : Int { + body { + doubler(a) + } + } + function double(x : Int) : Int { + body { + x + x + } + } + function doDouble(a : Double, doubler : (Double -> Double)) : Double { + body { + doubler(a) + } + } + function double(x : Double) : Double { + body { + x + x + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn unreachable_generic_call_site_not_specialized() { + // Monomorphize only processes reachable callables. + // The dead callable's generic call with a different type arg + // never generates a specialization. Verify that only the reachable + // Int specialization is produced. + check( + indoc! {" + namespace Test { + @EntryPoint() + function Main() : Int { + Identity(42) + } + function Identity<'T>(x : 'T) : 'T { x } + } + "}, + &expect![[r#" + Identity: generics=1, input=Param<0>, output=Param<0> + Identity: generics=0, input=Int, output=Int + Main: generics=0, input=Unit, output=Int"#]], + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/pretty.rs b/source/compiler/qsc_fir_transforms/src/pretty.rs new file mode 100644 index 0000000000..bcfea85760 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/pretty.rs @@ -0,0 +1,1153 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! FIR-to-Q# pretty-printer for pass debugging. +//! +//! Walks FIR structures via [`PackageLookup`]/[`PackageStoreLookup`] and +//! writes lexically valid Q# with minimal whitespace, then runs +//! [`qsc_formatter::formatter::format_str`] over the raw output. +//! +//! The emitter is intended for before/after snapshot tests of FIR +//! transform passes. 
It is best-effort — some FIR-only constructs render +//! as Q# comments or synthetic surface syntax: +//! +//! - [`ExprKind::Closure`] → `/* closure item= captures=[] */` +//! followed by a reference to the lifted callable item. +//! - [`ExprKind::ArrayLit`] renders with the same surface as +//! [`ExprKind::Array`]. +//! - [`ExprKind::AssignField`] / [`ExprKind::AssignIndex`] / +//! [`ExprKind::UpdateField`] / [`ExprKind::UpdateIndex`] render via the +//! idiomatic `r w/= F <- v` / `r w/ F <- v` forms. +//! - [`Field::Path`] chains indices as `::Item::Item` when UDT +//! metadata is not available; otherwise field names resolve through the +//! owning [`Udt`]. +//! - [`Ty::Prim`] renders via [`prim_as_qsharp`]. +//! +//! # Borrow strategy +//! +//! Walking the FIR requires shared borrows through [`PackageLookup`] while +//! also mutating the output buffer. The emitter resolves this by *cloning* +//! the FIR node kind at every traversal boundary (the nodes are cheap +//! struct/enum types) before calling back into `&mut self` helpers. + +#[cfg(test)] +mod tests; + +use qsc_fir::fir::{ + BinOp, BlockId, CallableDecl, CallableImpl, CallableKind, ExprId, ExprKind, Field, FieldAssign, + FieldPath, Functor, ItemId, ItemKind, Lit, LocalItemId, LocalVarId, Mutability, Package, + PackageId, PackageLookup, PackageStore, PackageStoreLookup, PatId, PatKind, Pauli, PrimField, + Res, Result as FirResult, SpecDecl, StmtId, StmtKind, StoreItemId, StringComponent, UnOp, +}; +use qsc_fir::ty::{Arrow, FunctorSet, FunctorSetValue, GenericArg, Prim, Ty, TypeParameter, Udt}; +use qsc_formatter::formatter::format_str; +use rustc_hash::FxHashMap; +use std::fmt::Write as _; +use std::rc::Rc; + +#[derive(Clone, Copy, Eq, PartialEq)] +enum RenderMode { + Debug, + Parseable, +} + +/// Renders the full FIR package as Q# source. 
+#[must_use] +pub fn write_package_qsharp(store: &PackageStore, package_id: PackageId) -> String { + let mut emitter = FirQSharpGen::new(store, package_id); + emitter.emit_package(); + format_str(&emitter.output) +} + +#[cfg(test)] +#[must_use] +pub(crate) fn write_package_qsharp_parseable( + store: &PackageStore, + package_id: PackageId, +) -> String { + let mut emitter = FirQSharpGen::new_with_mode(store, package_id, RenderMode::Parseable); + emitter.emit_package(); + format_str(&emitter.output) +} + +/// Renders a single callable item as Q# source. +#[must_use] +pub fn write_callable_qsharp( + store: &PackageStore, + package_id: PackageId, + item: LocalItemId, +) -> String { + let mut emitter = FirQSharpGen::new(store, package_id); + let decl = match &emitter.package().get_item(item).kind { + ItemKind::Callable(decl) => Some((**decl).clone()), + _ => None, + }; + if let Some(decl) = decl { + emitter.emit_callable_decl(&decl); + } + format_str(&emitter.output) +} + +/// Renders a single block as Q# source. +#[must_use] +pub fn write_block_qsharp(store: &PackageStore, package_id: PackageId, block: BlockId) -> String { + let mut emitter = FirQSharpGen::new(store, package_id); + emitter.emit_block(block); + format_str(&emitter.output) +} + +/// Renders a single expression as Q# source. +#[must_use] +pub fn write_expr_qsharp(store: &PackageStore, package_id: PackageId, expr: ExprId) -> String { + let mut emitter = FirQSharpGen::new(store, package_id); + emitter.emit_expr(expr); + format_str(&emitter.output) +} + +/// Renders a single statement as Q# source. 
+#[must_use] +pub fn write_stmt_qsharp(store: &PackageStore, package_id: PackageId, stmt: StmtId) -> String { + let mut emitter = FirQSharpGen::new(store, package_id); + emitter.emit_stmt(stmt); + format_str(&emitter.output) +} + +struct FirQSharpGen<'a> { + output: String, + store: &'a PackageStore, + package_id: PackageId, + local_names: FxHashMap>, + mode: RenderMode, +} + +impl<'a> FirQSharpGen<'a> { + fn new(store: &'a PackageStore, package_id: PackageId) -> Self { + Self::new_with_mode(store, package_id, RenderMode::Debug) + } + + fn new_with_mode(store: &'a PackageStore, package_id: PackageId, mode: RenderMode) -> Self { + Self { + output: String::new(), + store, + package_id, + local_names: FxHashMap::default(), + mode, + } + } + + fn package(&self) -> &Package { + self.store.get(self.package_id) + } + + fn write(&mut self, s: &str) { + self.output.push_str(s); + } + + fn writeln(&mut self, s: &str) { + self.output.push_str(s); + self.output.push('\n'); + } + + fn emit_package(&mut self) { + let ids: Vec = self.package().items.values().map(|i| i.id).collect(); + for id in ids { + self.emit_item(id); + } + let entry = self.package().entry; + if let Some(e) = entry { + self.writeln("// entry"); + self.emit_expr(e); + self.writeln(""); + } + } + + fn emit_item(&mut self, id: LocalItemId) { + let kind = self.package().get_item(id).kind.clone(); + match kind { + ItemKind::Callable(decl) => self.emit_callable_decl(&decl), + ItemKind::Namespace(name, _) => { + self.write("// namespace "); + self.write(&name.name); + self.writeln(""); + } + ItemKind::Ty(name, udt) => { + let ty = udt.get_pure_ty(); + self.write("newtype "); + self.write(&name.name); + self.write(" = "); + self.emit_ty(&ty); + self.writeln(";"); + } + ItemKind::Export(name, res) => { + self.write("// export "); + self.write(&name.name); + self.write(" = "); + self.emit_res(&res); + self.writeln(""); + } + } + } + + fn emit_callable_decl(&mut self, decl: &CallableDecl) { + let local_names = 
self.local_names_for_callable(decl); + let previous_local_names = std::mem::replace(&mut self.local_names, local_names); + + match decl.kind { + CallableKind::Function => self.write("function "), + CallableKind::Operation => self.write("operation "), + } + self.write(&self.render_ident(&decl.name.name)); + if !decl.generics.is_empty() { + self.write("<"); + for (i, g) in decl.generics.iter().enumerate() { + if i > 0 { + self.write(", "); + } + self.write(&type_parameter_name(g)); + } + self.write(">"); + } + self.emit_callable_input_pat(decl.input); + self.write(" : "); + self.emit_ty(&decl.output); + if decl.functors != FunctorSetValue::Empty { + self.write(" is "); + self.write(functor_set_value_as_str(decl.functors)); + } + + // Future optimization: omit the body label and braces when only a body exists. + + match &decl.implementation { + CallableImpl::Intrinsic => { + self.writeln(" { body intrinsic; }"); + } + CallableImpl::Spec(spec) => { + let body = spec.body.clone(); + let adj = spec.adj.clone(); + let ctl = spec.ctl.clone(); + let ctl_adj = spec.ctl_adj.clone(); + if self.mode == RenderMode::Parseable + && adj.is_none() + && ctl.is_none() + && ctl_adj.is_none() + { + self.emit_block(body.block); + self.local_names = previous_local_names; + return; + } + self.writeln(" {"); + self.emit_spec_decl("body", &body); + if let Some(s) = adj { + self.emit_spec_decl("adjoint", &s); + } + if let Some(s) = ctl { + self.emit_spec_decl("controlled", &s); + } + if let Some(s) = ctl_adj { + self.emit_spec_decl("controlled adjoint", &s); + } + self.writeln("}"); + } + CallableImpl::SimulatableIntrinsic(spec) => { + let spec = spec.clone(); + self.writeln(" {"); + self.emit_spec_decl("body", &spec); + self.writeln("}"); + } + } + + self.local_names = previous_local_names; + } + + fn local_names_for_callable(&self, decl: &CallableDecl) -> FxHashMap> { + let mut local_names = FxHashMap::default(); + self.collect_pat_names(decl.input, &mut local_names); + if self.mode == 
RenderMode::Parseable { + match &decl.implementation { + CallableImpl::Intrinsic => {} + CallableImpl::Spec(spec) => { + for spec in std::iter::once(&spec.body) + .chain(spec.adj.iter()) + .chain(spec.ctl.iter()) + .chain(spec.ctl_adj.iter()) + { + if let Some(input_pat) = spec.input { + self.collect_pat_names(input_pat, &mut local_names); + } + } + } + CallableImpl::SimulatableIntrinsic(spec) => { + if let Some(input_pat) = spec.input { + self.collect_pat_names(input_pat, &mut local_names); + } + } + } + } + self.collect_impl_local_names(&decl.implementation, &mut local_names); + local_names + } + + fn collect_impl_local_names( + &self, + implementation: &CallableImpl, + local_names: &mut FxHashMap>, + ) { + match implementation { + CallableImpl::Intrinsic => {} + CallableImpl::Spec(spec) => { + self.collect_spec_decl_local_names(&spec.body, local_names); + if let Some(adj) = &spec.adj { + self.collect_spec_decl_local_names(adj, local_names); + } + if let Some(ctl) = &spec.ctl { + self.collect_spec_decl_local_names(ctl, local_names); + } + if let Some(ctl_adj) = &spec.ctl_adj { + self.collect_spec_decl_local_names(ctl_adj, local_names); + } + } + CallableImpl::SimulatableIntrinsic(spec) => { + self.collect_spec_decl_local_names(spec, local_names); + } + } + } + + fn collect_spec_decl_local_names( + &self, + spec: &SpecDecl, + local_names: &mut FxHashMap>, + ) { + self.collect_block_local_names(spec.block, local_names); + } + + fn collect_block_local_names( + &self, + block_id: BlockId, + local_names: &mut FxHashMap>, + ) { + for &stmt_id in &self.package().get_block(block_id).stmts { + let stmt = self.package().get_stmt(stmt_id); + match &stmt.kind { + StmtKind::Expr(expr) | StmtKind::Semi(expr) => { + self.collect_expr_local_names(*expr, local_names); + } + StmtKind::Local(_, pat_id, expr) => { + self.collect_pat_names(*pat_id, local_names); + self.collect_expr_local_names(*expr, local_names); + } + StmtKind::Item(_) => {} + } + } + } + + fn 
collect_expr_local_names( + &self, + expr_id: ExprId, + local_names: &mut FxHashMap>, + ) { + let kind = &self.package().get_expr(expr_id).kind; + match kind { + ExprKind::Array(exprs) | ExprKind::ArrayLit(exprs) | ExprKind::Tuple(exprs) => { + for &expr in exprs { + self.collect_expr_local_names(expr, local_names); + } + } + ExprKind::ArrayRepeat(item, size) + | ExprKind::Assign(item, size) + | ExprKind::AssignOp(_, item, size) + | ExprKind::BinOp(_, item, size) + | ExprKind::Call(item, size) + | ExprKind::Index(item, size) + | ExprKind::AssignField(item, _, size) + | ExprKind::UpdateField(item, _, size) => { + self.collect_expr_local_names(*item, local_names); + self.collect_expr_local_names(*size, local_names); + } + ExprKind::AssignIndex(array, index, value) + | ExprKind::UpdateIndex(array, index, value) => { + self.collect_expr_local_names(*array, local_names); + self.collect_expr_local_names(*index, local_names); + self.collect_expr_local_names(*value, local_names); + } + ExprKind::Block(block) => self.collect_block_local_names(*block, local_names), + ExprKind::Closure(_, _) | ExprKind::Hole | ExprKind::Lit(_) | ExprKind::Var(_, _) => {} + ExprKind::Fail(expr) + | ExprKind::Field(expr, _) + | ExprKind::Return(expr) + | ExprKind::UnOp(_, expr) => self.collect_expr_local_names(*expr, local_names), + ExprKind::If(cond, body, otherwise) => { + self.collect_expr_local_names(*cond, local_names); + self.collect_expr_local_names(*body, local_names); + if let Some(otherwise) = otherwise { + self.collect_expr_local_names(*otherwise, local_names); + } + } + ExprKind::Range(start, step, end) => { + for expr in [start, step, end].into_iter().flatten() { + self.collect_expr_local_names(*expr, local_names); + } + } + ExprKind::Struct(_, copy, fields) => { + if let Some(copy) = copy { + self.collect_expr_local_names(*copy, local_names); + } + for field in fields { + self.collect_expr_local_names(field.value, local_names); + } + } + ExprKind::String(components) => { + for 
component in components { + if let StringComponent::Expr(expr) = component { + self.collect_expr_local_names(*expr, local_names); + } + } + } + ExprKind::While(cond, block) => { + self.collect_expr_local_names(*cond, local_names); + self.collect_block_local_names(*block, local_names); + } + } + } + + fn collect_pat_names(&self, pat_id: PatId, local_names: &mut FxHashMap>) { + let pat = self.package().get_pat(pat_id); + match &pat.kind { + PatKind::Bind(ident) => { + local_names.insert(ident.id, Rc::from(self.render_ident(&ident.name))); + } + PatKind::Tuple(pats) => { + for &pat in pats { + self.collect_pat_names(pat, local_names); + } + } + PatKind::Discard => {} + } + } + + fn emit_spec_decl(&mut self, label: &str, spec: &SpecDecl) { + if self.mode == RenderMode::Parseable { + self.emit_parseable_spec_decl(label, spec); + return; + } + self.write(label); + self.emit_block(spec.block); + } + + fn emit_parseable_spec_decl(&mut self, label: &str, spec: &SpecDecl) { + self.write(label); + match label { + "body" | "adjoint" => { + self.write(" ..."); + } + "controlled" | "controlled adjoint" => { + if let Some(input_pat) = spec.input { + self.write(" ("); + self.emit_pat_bindings(self.control_pat(input_pat)); + self.write(", ...)"); + } + } + _ => {} + } + self.emit_block(spec.block); + } + + fn control_pat(&self, input_pat: PatId) -> PatId { + let pat = self.package().get_pat(input_pat); + match &pat.kind { + PatKind::Tuple(pats) => pats.first().copied().unwrap_or(input_pat), + PatKind::Bind(_) | PatKind::Discard => input_pat, + } + } + + fn emit_block(&mut self, block_id: BlockId) { + let stmts = self.package().get_block(block_id).stmts.clone(); + self.writeln(" {"); + for stmt in stmts { + self.emit_stmt(stmt); + } + self.writeln("}"); + } + + fn emit_stmt(&mut self, stmt_id: StmtId) { + let kind = self.package().get_stmt(stmt_id).kind.clone(); + match kind { + StmtKind::Expr(e) => { + self.emit_expr(e); + self.writeln(""); + } + StmtKind::Semi(e) => { + 
self.emit_expr(e); + self.writeln(";"); + } + StmtKind::Local(mutability, pat_id, expr) => { + match mutability { + Mutability::Immutable => self.write("let "), + Mutability::Mutable => self.write("mutable "), + } + self.emit_pat(pat_id); + self.write(" = "); + self.emit_expr(expr); + self.writeln(";"); + } + StmtKind::Item(item_id) => { + self.write("// item "); + self.write(&format!("{item_id}")); + self.writeln(""); + } + } + } + + fn emit_expr(&mut self, expr_id: ExprId) { + let kind = self.package().get_expr(expr_id).kind.clone(); + self.emit_expr_kind(&kind); + } + + #[allow(clippy::too_many_lines)] + fn emit_expr_kind(&mut self, kind: &ExprKind) { + match kind { + ExprKind::Array(exprs) | ExprKind::ArrayLit(exprs) => { + self.write("["); + self.emit_comma_separated_exprs(exprs); + self.write("]"); + } + ExprKind::ArrayRepeat(item, size) => { + self.write("["); + self.emit_expr(*item); + self.write(", size = "); + self.emit_expr(*size); + self.write("]"); + } + ExprKind::Assign(lhs, rhs) => { + self.emit_expr(*lhs); + self.write(" = "); + self.emit_expr(*rhs); + } + ExprKind::AssignOp(op, lhs, rhs) => { + self.emit_expr(*lhs); + self.write(" "); + self.write(binop_as_str(*op)); + self.write("= "); + self.emit_expr(*rhs); + } + ExprKind::AssignField(record, field, value) => { + self.emit_expr(*record); + self.write(" w/= "); + self.emit_field(*record, field); + self.write(" <- "); + self.emit_expr(*value); + } + ExprKind::AssignIndex(array, index, value) => { + self.emit_expr(*array); + self.write(" w/= "); + self.emit_expr(*index); + self.write(" <- "); + self.emit_expr(*value); + } + ExprKind::BinOp(op, lhs, rhs) => { + self.emit_expr(*lhs); + self.write(" "); + self.write(binop_as_str(*op)); + self.write(" "); + self.emit_expr(*rhs); + } + ExprKind::Block(block) => self.emit_block(*block), + ExprKind::Call(callee, arg) => { + self.emit_expr(*callee); + // Argument must be tuple-like to emit as `callee(args)`; for + // non-tuple args, wrap in parens 
ourselves. + let arg_is_tuple = matches!(self.package().get_expr(*arg).kind, ExprKind::Tuple(_)); + if arg_is_tuple { + self.emit_expr(*arg); + } else { + self.write("("); + self.emit_expr(*arg); + self.write(")"); + } + } + ExprKind::Closure(captures, item) => { + self.write("/* closure item="); + self.write(&format!("{item}")); + self.write(" captures=["); + for (i, local) in captures.iter().enumerate() { + if i > 0 { + self.write(", "); + } + let display = self.local_display(*local); + self.write(&display); + } + self.write("] */ "); + let name = self.callable_name_for(*item); + self.write(&name); + } + ExprKind::Fail(e) => { + self.write("fail "); + self.emit_expr(*e); + } + ExprKind::Field(record, field) => { + self.emit_expr(*record); + self.emit_field(*record, field); + } + ExprKind::Hole => self.write("_"), + ExprKind::If(cond, body, otherwise) => { + self.write("if "); + self.emit_expr(*cond); + self.write(" "); + self.emit_if_branch(*body); + if let Some(e) = otherwise { + if self.mode == RenderMode::Parseable { + self.write(" else "); + if matches!(self.package().get_expr(*e).kind, ExprKind::If(..)) { + self.emit_expr(*e); + } else { + self.emit_if_branch(*e); + } + } else { + let is_elif = matches!(self.package().get_expr(*e).kind, ExprKind::If(..)); + if is_elif { + self.write(" el"); + } else { + self.write(" else "); + } + self.emit_expr(*e); + } + } + } + ExprKind::Index(array, index) => { + self.emit_expr(*array); + self.write("["); + self.emit_expr(*index); + self.write("]"); + } + ExprKind::Lit(lit) => self.emit_lit(lit), + ExprKind::Range(start, step, end) => { + self.emit_range(*start, *step, *end); + } + ExprKind::Return(e) => { + self.write("return "); + self.emit_expr(*e); + } + ExprKind::Struct(res, copy, fields) => { + self.write("new "); + self.emit_res(res); + self.writeln(" {"); + if let Some(c) = copy { + self.write("..."); + self.emit_expr(*c); + if !fields.is_empty() { + self.writeln(","); + } + } + let struct_ty = match res { + 
Res::Item(_) => Ty::Udt(*res), + _ => Ty::Err, + }; + self.emit_field_assigns(&struct_ty, fields); + self.writeln("}"); + } + ExprKind::String(components) => { + self.write("$\""); + for component in components { + match component { + StringComponent::Expr(e) => { + self.write("{"); + self.emit_expr(*e); + self.write("}"); + } + StringComponent::Lit(s) => self.write(s), + } + } + self.write("\""); + } + ExprKind::Tuple(exprs) => { + self.write("("); + if let Some((last, most)) = exprs.split_last() { + for e in most { + self.emit_expr(*e); + self.write(", "); + } + self.emit_expr(*last); + if most.is_empty() { + self.write(","); + } + } + self.write(")"); + } + ExprKind::UnOp(op, expr) => { + let op_str = unop_as_str(*op); + if matches!(op, UnOp::Unwrap) { + self.emit_expr(*expr); + self.write(op_str); + } else { + self.write(op_str); + self.emit_expr(*expr); + } + } + ExprKind::UpdateField(record, field, value) => { + self.emit_expr(*record); + self.write(" w/ "); + self.emit_field(*record, field); + self.write(" <- "); + self.emit_expr(*value); + } + ExprKind::UpdateIndex(array, index, value) => { + self.emit_expr(*array); + self.write(" w/ "); + self.emit_expr(*index); + self.write(" <- "); + self.emit_expr(*value); + } + ExprKind::Var(res, args) => { + self.emit_res(res); + if !args.is_empty() { + self.write("<"); + for (i, arg) in args.iter().enumerate() { + if i > 0 { + self.write(", "); + } + self.emit_generic_arg(arg); + } + self.write(">"); + } + } + ExprKind::While(cond, block) => { + self.write("while "); + self.emit_expr(*cond); + self.emit_block(*block); + } + } + } + + fn emit_comma_separated_exprs(&mut self, exprs: &[ExprId]) { + if let Some((last, most)) = exprs.split_last() { + for e in most { + self.emit_expr(*e); + self.write(", "); + } + self.emit_expr(*last); + } + } + + fn emit_field_assigns(&mut self, record_ty: &Ty, fields: &[FieldAssign]) { + if let Some((last, most)) = fields.split_last() { + for fa in most { + 
self.emit_field_assign(record_ty, fa); + self.writeln(","); + } + self.emit_field_assign(record_ty, last); + self.writeln(""); + } + } + + fn emit_field_assign(&mut self, record_ty: &Ty, fa: &FieldAssign) { + let display = self.field_display(record_ty, &fa.field); + // Field::Path renders as "::Name"; strip the leading "::" in struct + // constructor assignments to match idiomatic Q#. + let trimmed = display.strip_prefix("::").unwrap_or(&display); + self.write(trimmed); + self.write(" = "); + self.emit_expr(fa.value); + } + + fn emit_range(&mut self, start: Option, step: Option, end: Option) { + match (start, step, end) { + (None, None, None) => self.write("..."), + (None, None, Some(e)) => { + self.write("..."); + self.emit_expr(e); + } + (None, Some(s), None) => { + self.write("..."); + self.emit_expr(s); + self.write("..."); + } + (None, Some(s), Some(e)) => { + self.write("..."); + self.emit_expr(s); + self.write(".."); + self.emit_expr(e); + } + (Some(s), None, None) => { + self.emit_expr(s); + self.write("..."); + } + (Some(s), None, Some(e)) => { + self.emit_expr(s); + self.write(".."); + self.emit_expr(e); + } + (Some(s), Some(step), None) => { + self.emit_expr(s); + self.write(".."); + self.emit_expr(step); + self.write("..."); + } + (Some(s), Some(step), Some(e)) => { + self.emit_expr(s); + self.write(".."); + self.emit_expr(step); + self.write(".."); + self.emit_expr(e); + } + } + } + + fn emit_lit(&mut self, lit: &Lit) { + match lit { + Lit::BigInt(v) => { + self.write(&v.to_string()); + self.write("L"); + } + Lit::Bool(v) => self.write(if *v { "true" } else { "false" }), + Lit::Double(v) => { + let s = if v.fract() == 0.0 { + format!("{v}.") + } else { + format!("{v}") + }; + self.write(&s); + } + Lit::Int(v) => self.write(&v.to_string()), + Lit::Pauli(p) => self.write(match p { + Pauli::I => "PauliI", + Pauli::X => "PauliX", + Pauli::Y => "PauliY", + Pauli::Z => "PauliZ", + }), + Lit::Result(r) => self.write(match r { + FirResult::Zero => "Zero", + 
FirResult::One => "One", + }), + } + } + + fn emit_callable_input_pat(&mut self, pat_id: PatId) { + if matches!(self.package().get_pat(pat_id).kind, PatKind::Tuple(_)) { + self.emit_pat(pat_id); + } else { + self.write("("); + self.emit_pat(pat_id); + self.write(")"); + } + } + + fn emit_pat(&mut self, pat_id: PatId) { + let pat = self.package().get_pat(pat_id).clone(); + match pat.kind { + PatKind::Bind(ident) => { + self.write(&self.render_ident(&ident.name)); + self.write(" : "); + self.emit_ty(&pat.ty); + } + PatKind::Discard => { + self.write("_ : "); + self.emit_ty(&pat.ty); + } + PatKind::Tuple(pats) => { + self.write("("); + if let Some((last, most)) = pats.split_last() { + for p in most { + self.emit_pat(*p); + self.write(", "); + } + self.emit_pat(*last); + if most.is_empty() { + self.write(","); + } + } + self.write(")"); + } + } + } + + fn emit_res(&mut self, res: &Res) { + match res { + Res::Err => self.write("/* err */"), + Res::Local(local) => { + let display = self.local_display(*local); + self.write(&display); + } + Res::Item(item_id) => { + let name = self.item_name(*item_id); + self.write(&name); + } + } + } + + fn emit_if_branch(&mut self, expr_id: ExprId) { + if self.mode != RenderMode::Parseable + || matches!(self.package().get_expr(expr_id).kind, ExprKind::Block(_)) + { + self.emit_expr(expr_id); + return; + } + + self.writeln(" {"); + self.emit_expr(expr_id); + self.writeln(""); + self.write("}"); + } + + fn emit_pat_bindings(&mut self, pat_id: PatId) { + let pat = self.package().get_pat(pat_id).clone(); + match pat.kind { + PatKind::Bind(ident) => self.write(&self.render_ident(&ident.name)), + PatKind::Discard => self.write("_"), + PatKind::Tuple(pats) => { + self.write("("); + if let Some((last, most)) = pats.split_last() { + for p in most { + self.emit_pat_bindings(*p); + self.write(", "); + } + self.emit_pat_bindings(*last); + if most.is_empty() { + self.write(","); + } + } + self.write(")"); + } + } + } + + fn render_ident(&self, name: 
&str) -> String { + if self.mode != RenderMode::Parseable { + return name.to_string(); + } + + let mut rendered = String::with_capacity(name.len()); + for (index, ch) in name.chars().enumerate() { + let is_valid = if index == 0 { + ch == '_' || ch.is_ascii_alphabetic() + } else { + ch == '_' || ch.is_ascii_alphanumeric() + }; + rendered.push(if is_valid { ch } else { '_' }); + } + if rendered.is_empty() { + rendered.push('_'); + } + rendered + } + + fn local_display(&self, local: LocalVarId) -> String { + match self.local_names.get(&local) { + Some(name) => name.to_string(), + None => format!("_local{local}"), + } + } + + fn callable_name_for(&self, item: LocalItemId) -> String { + let pkg = self.package(); + match &pkg.get_item(item).kind { + ItemKind::Callable(decl) => decl.name.name.to_string(), + ItemKind::Ty(name, _) => name.name.to_string(), + _ => format!("Item({item})"), + } + } + + fn item_name(&self, item_id: ItemId) -> String { + if item_id.package == self.package_id { + self.callable_name_for(item_id.item) + } else { + let store_id = StoreItemId { + package: item_id.package, + item: item_id.item, + }; + match &self.store.get_item(store_id).kind { + ItemKind::Callable(decl) => decl.name.name.to_string(), + ItemKind::Ty(name, _) => name.name.to_string(), + _ => format!("{item_id}"), + } + } + } + + fn emit_field(&mut self, record: ExprId, field: &Field) { + let record_ty = self.package().get_expr(record).ty.clone(); + let display = self.field_display(&record_ty, field); + self.write(&display); + } + + fn field_display(&self, record_ty: &Ty, field: &Field) -> String { + match field { + Field::Err => "::/* err */".to_string(), + Field::Prim(prim) => match prim { + PrimField::Start => "::Start".to_string(), + PrimField::Step => "::Step".to_string(), + PrimField::End => "::End".to_string(), + }, + Field::Path(path) => self.resolve_field_path(record_ty, path), + } + } + + fn resolve_field_path(&self, record_ty: &Ty, path: &FieldPath) -> String { + if let 
Some(udt) = self.lookup_udt(record_ty) + && let Some(name) = udt_field_name(udt, path) + { + return format!("::{name}"); + } + let mut out = String::new(); + for idx in &path.indices { + let _ = write!(out, "::Item<{idx}>"); + } + out + } + + fn lookup_udt(&self, ty: &Ty) -> Option<&Udt> { + let Ty::Udt(Res::Item(item_id)) = ty else { + return None; + }; + let store_id = StoreItemId { + package: item_id.package, + item: item_id.item, + }; + let item = self.store.get_item(store_id); + match &item.kind { + ItemKind::Ty(_, udt) => Some(udt), + _ => None, + } + } + + fn emit_ty(&mut self, ty: &Ty) { + self.write(&ty_as_qsharp(ty)); + } + + fn emit_generic_arg(&mut self, arg: &GenericArg) { + match arg { + GenericArg::Ty(ty) => self.emit_ty(ty), + GenericArg::Functor(FunctorSet::Value(fsv)) => { + self.write(functor_set_value_as_str(*fsv)); + } + GenericArg::Functor(FunctorSet::Param(p)) => { + self.write(&format!("functor<{p}>")); + } + GenericArg::Functor(FunctorSet::Infer(_)) => { + self.write("functor"); + } + } + } +} + +fn binop_as_str(op: BinOp) -> &'static str { + match op { + BinOp::Add => "+", + BinOp::AndB => "&&&", + BinOp::AndL => "and", + BinOp::Div => "/", + BinOp::Eq => "==", + BinOp::Exp => "^", + BinOp::Gt => ">", + BinOp::Gte => ">=", + BinOp::Lt => "<", + BinOp::Lte => "<=", + BinOp::Mod => "%", + BinOp::Mul => "*", + BinOp::Neq => "!=", + BinOp::OrB => "|||", + BinOp::OrL => "or", + BinOp::Shl => "<<<", + BinOp::Shr => ">>>", + BinOp::Sub => "-", + BinOp::XorB => "^^^", + } +} + +fn unop_as_str(op: UnOp) -> &'static str { + match op { + UnOp::Functor(Functor::Adj) => "Adjoint ", + UnOp::Functor(Functor::Ctl) => "Controlled ", + UnOp::Neg => "-", + UnOp::NotB => "~~~", + UnOp::NotL => "not ", + UnOp::Pos => "+", + UnOp::Unwrap => "!", + } +} + +fn functor_set_value_as_str(fsv: FunctorSetValue) -> &'static str { + match fsv { + FunctorSetValue::Empty => "()", + FunctorSetValue::Adj => "Adj", + FunctorSetValue::Ctl => "Ctl", + FunctorSetValue::CtlAdj 
=> "Adj + Ctl", + } +} + +fn prim_as_qsharp(prim: Prim) -> &'static str { + match prim { + Prim::BigInt => "BigInt", + Prim::Bool => "Bool", + Prim::Double => "Double", + Prim::Int => "Int", + Prim::Pauli => "Pauli", + Prim::Qubit => "Qubit", + Prim::Range | Prim::RangeTo | Prim::RangeFrom | Prim::RangeFull => "Range", + Prim::Result => "Result", + Prim::String => "String", + } +} + +fn ty_as_qsharp(ty: &Ty) -> String { + match ty { + Ty::Array(item) => format!("{}[]", ty_as_qsharp(item)), + Ty::Arrow(arrow) => arrow_as_qsharp(arrow), + Ty::Infer(_) => "_".to_string(), + Ty::Param(p) => format!("'T{p}"), + Ty::Prim(p) => prim_as_qsharp(*p).to_string(), + Ty::Tuple(items) => { + if items.is_empty() { + "Unit".to_string() + } else if items.len() == 1 { + format!("({},)", ty_as_qsharp(&items[0])) + } else { + let parts: Vec<_> = items.iter().map(ty_as_qsharp).collect(); + format!("({})", parts.join(", ")) + } + } + Ty::Udt(Res::Item(item_id)) => format!("UDT<{item_id}>"), + Ty::Udt(Res::Local(local)) => format!("UDT"), + Ty::Udt(Res::Err) => "UDT".to_string(), + Ty::Err => "?".to_string(), + } +} + +fn arrow_as_qsharp(arrow: &Arrow) -> String { + let sep = match arrow.kind { + CallableKind::Function => "->", + CallableKind::Operation => "=>", + }; + let input = ty_as_qsharp(&arrow.input); + let output = ty_as_qsharp(&arrow.output); + match arrow.functors { + FunctorSet::Value(FunctorSetValue::Empty) => format!("({input} {sep} {output})"), + FunctorSet::Value(v) => format!( + "({input} {sep} {output} is {})", + functor_set_value_as_str(v) + ), + FunctorSet::Param(p) => format!("({input} {sep} {output} is functor<{p}>)"), + FunctorSet::Infer(_) => format!("({input} {sep} {output} is functor)"), + } +} + +fn type_parameter_name(p: &TypeParameter) -> String { + match p { + TypeParameter::Ty { name, .. 
} => format!("'{name}"), + TypeParameter::Functor(fsv) => format!("functor<{}>", functor_set_value_as_str(*fsv)), + } +} + +fn udt_field_name(udt: &Udt, path: &FieldPath) -> Option> { + use qsc_fir::ty::UdtDefKind; + let mut def = &udt.definition; + for &index in &path.indices { + match &def.kind { + UdtDefKind::Tuple(items) => { + def = items.get(index)?; + } + UdtDefKind::Field(_) => return None, + } + } + match &def.kind { + UdtDefKind::Field(f) => f.name.clone(), + UdtDefKind::Tuple(_) => None, + } +} diff --git a/source/compiler/qsc_fir_transforms/src/pretty/tests.rs b/source/compiler/qsc_fir_transforms/src/pretty/tests.rs new file mode 100644 index 0000000000..36a559c5e8 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/pretty/tests.rs @@ -0,0 +1,258 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use super::*; +use crate::test_utils::{PipelineStage, compile_and_run_pipeline_to}; +use expect_test::{Expect, expect}; +use indoc::indoc; + +fn render_after_mono(source: &str) -> String { + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Mono); + write_package_qsharp(&store, pkg_id) +} + +fn check_render(source: &str, expect: &Expect) { + expect.assert_eq(&render_after_mono(source)); +} + +#[test] +fn simple_function_renders() { + check_render( + indoc! {r#" + namespace Test { + function Add(a : Int, b : Int) : Int { + a + b + } + @EntryPoint() + function Main() : Int { + Add(1, 2) + } + } + "#}, + &expect![[r#" + // namespace Test + function Add(a : Int, b : Int) : Int { + body { + a + b + } + } + function Main() : Int { + body { + Add(1, 2) + } + } + // entry + Main() + "#]], // snapshot populated by UPDATE_EXPECT=1 + ); +} + +#[test] +fn operation_with_specializations_renders() { + check_render( + indoc! {r#" + namespace Test { + operation Op(q : Qubit) : Unit is Adj + Ctl { + body ... { X(q); } + adjoint ... { X(q); } + controlled (ctls, ...) 
{ Controlled X(ctls, q); } + controlled adjoint (ctls, ...) { Controlled X(ctls, q); } + } + @EntryPoint() + operation Main() : Unit { + use q = Qubit(); + Op(q); + } + } + "#}, + &expect![[r#" + // namespace Test + operation Op(q : Qubit) : Unit is Adj + Ctl { + body { + X(q); + } + adjoint { + X(q); + } + controlled { + Controlled X(_local2, q); + } + controlled adjoint { + Controlled X(_local3, q); + } + } + operation Main() : Unit { + body { + let q : Qubit = __quantum__rt__qubit_allocate(); + Op(q); + __quantum__rt__qubit_release(q); + } + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], // snapshot populated by UPDATE_EXPECT=1 + ); +} + +#[test] +fn nested_block_renders() { + check_render( + indoc! {r#" + namespace Test { + @EntryPoint() + function Main() : Int { + let x = { + let y = 1; + y + 2 + }; + x + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + let x : Int = { + let y : Int = 1; + y + 2 + }; + x + } + } + // entry + Main() + "#]], // snapshot populated by UPDATE_EXPECT=1 + ); +} + +#[test] +fn common_expr_kinds_render() { + check_render( + indoc! {r#" + namespace Test { + @EntryPoint() + function Main() : Int { + mutable arr = [1, 2, 3]; + arr w/= 0 <- 42; + let r = arr w/ 1 <- 99; + let tup = (1, 2, 3); + let s = $"value is {tup}"; + if arr[0] > 0 { + arr[0] + } else { + -1 + } + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + mutable arr : Int[] = [1, 2, 3]; + arr w/= 0 <- 42; + let r : Int[] = arr w/ 1 <- 99; + let tup : (Int, Int, Int) = (1, 2, 3); + let s : String = $"value is {tup}"; + if arr[0] > 0 { + arr[0] + } else { + -1 + } + + } + } + // entry + Main() + "#]], // snapshot populated by UPDATE_EXPECT=1 + ); +} + +#[test] +fn udt_field_renders_by_name_when_available() { + check_render( + indoc! 
{r#" + namespace Test { + newtype Pair = (First : Int, Second : Int); + @EntryPoint() + function Main() : Int { + let p = Pair(1, 2); + p::First + } + } + "#}, + &expect![[r#" + // namespace Test + newtype Pair = (Int, Int); + function Main() : Int { + body { + let p : UDT < Item 1(Package 2) > = Pair(1, 2); + p::First + } + } + // entry + Main() + "#]], // snapshot populated by UPDATE_EXPECT=1 + ); +} + +#[test] +fn write_expr_renders_expression() { + let src = indoc! {r#" + namespace Test { + @EntryPoint() + function Main() : Int { + 1 + 2 + } + } + "#}; + let (store, pkg_id) = compile_and_run_pipeline_to(src, PipelineStage::Mono); + let pkg = store.get(pkg_id); + let mut found = None; + for item in pkg.items.values() { + if let ItemKind::Callable(decl) = &item.kind + && decl.name.name.as_ref() == "Main" + && let CallableImpl::Spec(spec) = &decl.implementation + { + let block = pkg.get_block(spec.body.block); + if let Some(&stmt_id) = block.stmts.first() { + let stmt = pkg.get_stmt(stmt_id); + if let StmtKind::Expr(e) | StmtKind::Semi(e) = &stmt.kind { + found = Some(*e); + } + } + } + } + let expr_id = found.expect("Main body has a trailing expression"); + let rendered = write_expr_qsharp(&store, pkg_id, expr_id); + expect!["1 + 2"] // snapshot populated by UPDATE_EXPECT=1 + .assert_eq(&rendered); +} + +#[test] +fn binop_as_str_covers_representative_variants() { + assert_eq!(binop_as_str(BinOp::Add), "+"); + assert_eq!(binop_as_str(BinOp::AndL), "and"); + assert_eq!(binop_as_str(BinOp::Shl), "<<<"); +} + +#[test] +fn unop_as_str_covers_functors() { + assert_eq!(unop_as_str(UnOp::Functor(Functor::Adj)), "Adjoint "); + assert_eq!(unop_as_str(UnOp::Functor(Functor::Ctl)), "Controlled "); + assert_eq!(unop_as_str(UnOp::Unwrap), "!"); +} + +#[test] +fn ty_rendering_handles_primitives_and_tuples() { + assert_eq!(ty_as_qsharp(&Ty::Prim(Prim::Int)), "Int"); + assert_eq!(ty_as_qsharp(&Ty::Tuple(Vec::new())), "Unit"); + assert_eq!( + 
ty_as_qsharp(&Ty::Array(Box::new(Ty::Prim(Prim::Bool)))), + "Bool[]" + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/reachability.rs b/source/compiler/qsc_fir_transforms/src/reachability.rs new file mode 100644 index 0000000000..9594daa618 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/reachability.rs @@ -0,0 +1,203 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Entry-rooted call graph walker. +//! +//! [`collect_reachable_from_entry`] starts from a package's entry expression +//! and transitively discovers every callable item reachable through the FIR +//! call graph, including cross-package references. +//! +//! The algorithm is a worklist-based breadth-first walk. Starting from the +//! entry expression, it follows every `Res::Item` reference encountered in +//! expression trees, adding newly discovered +//! callables to the worklist until a fixed point is reached. +//! +//! [`collect_reachable_with_seeds`] extends this by accepting additional +//! pinned items as extra roots alongside the entry expression. +//! +//! [`collect_reachable_package_closure`] computes the cross-package +//! reachability closure needed by UDT erasure to determine which packages +//! require type-item removal. + +#[cfg(test)] +mod tests; + +use qsc_fir::fir::{CallableImpl, ExprKind, ItemKind, PackageId, PackageStore, Res, StoreItemId}; +use rustc_hash::FxHashSet; + +/// Returns the set of all callable items transitively reachable from the entry +/// expression of the given package. +/// +/// Cross-package references are followed, so the result may contain items from +/// library packages. Intrinsic callables are included as reachable (they have +/// no body to walk but are still referenced). +/// +/// # Scoping contract +/// +/// - **Missing items are silently skipped.** Interpreter entry expressions +/// can carry runtime-unbound item references that survive a rejected +/// callable definition. 
When the worklist encounters a `StoreItemId` that
+/// no longer exists in its package's item table, the walker drops it and
+/// continues; later evaluation reports the diagnostic instead of failing
+/// here.
+/// - **Closures resolve in the current package only.**
+/// [`ExprKind::Closure(_, local_item_id)`](ExprKind::Closure) carries a
+/// bare [`LocalItemId`](qsc_fir::fir::LocalItemId); the walker pairs it
+/// with the *containing* package id rather than any source package id. As
+/// a result closures cannot point outside the package in which they
+/// appear, and the walker treats them accordingly.
+///
+/// # Panics
+///
+/// Panics if the package has no entry expression.
+#[must_use]
+pub fn collect_reachable_from_entry(
+    store: &PackageStore,
+    package_id: PackageId,
+) -> FxHashSet<StoreItemId> {
+    let package = store.get(package_id);
+    let entry_expr_id = package
+        .entry
+        .expect("package must have an entry expression");
+
+    let mut visited = FxHashSet::default();
+    let mut worklist: Vec<StoreItemId> = Vec::new();
+
+    walk_expr(store, package_id, entry_expr_id, &mut worklist);
+
+    while let Some(item_id) = worklist.pop() {
+        if visited.contains(&item_id) {
+            continue;
+        }
+        let item_pkg = store.get(item_id.package);
+        let Some(item) = item_pkg.items.get(item_id.item) else {
+            // Interpreter entry expressions can carry runtime-unbound item references
+            // after a rejected callable definition. Leave those for later evaluation
+            // diagnostics instead of panicking during reachability discovery.
+            continue;
+        };
+        visited.insert(item_id);
+        if let ItemKind::Callable(decl) = &item.kind {
+            walk_callable_impl(store, item_id.package, &decl.implementation, &mut worklist);
+        }
+    }
+
+    visited
+}
+
+/// Returns the set of all callable items transitively reachable from the
+/// entry expression **and** from the additional `seeds`.
+///
+/// Seeds are added to the worklist alongside the items discovered from the
+/// entry expression, so their transitive dependencies are also included in
+/// the output set.
+///
+/// # Panics
+///
+/// Panics if the package has no entry expression.
+#[must_use]
+pub fn collect_reachable_with_seeds(
+    store: &PackageStore,
+    package_id: PackageId,
+    seeds: &[StoreItemId],
+) -> FxHashSet<StoreItemId> {
+    let package = store.get(package_id);
+    let entry_expr_id = package
+        .entry
+        .expect("package must have an entry expression");
+
+    let mut visited = FxHashSet::default();
+    let mut worklist: Vec<StoreItemId> = seeds.to_vec();
+
+    walk_expr(store, package_id, entry_expr_id, &mut worklist);
+
+    while let Some(item_id) = worklist.pop() {
+        if visited.contains(&item_id) {
+            continue;
+        }
+        let item_pkg = store.get(item_id.package);
+        let Some(item) = item_pkg.items.get(item_id.item) else {
+            continue;
+        };
+        visited.insert(item_id);
+        if let ItemKind::Callable(decl) = &item.kind {
+            walk_callable_impl(store, item_id.package, &decl.implementation, &mut worklist);
+        }
+    }
+
+    visited
+}
+
+/// Returns the package closure induced by an entry-reachable callable set.
+///
+/// The returned set always includes the root package, even when the entry
+/// expression reaches no other callables.
+#[must_use]
+pub fn collect_reachable_package_closure<'a>(
+    package_id: PackageId,
+    reachable: impl IntoIterator<Item = &'a StoreItemId>,
+) -> FxHashSet<PackageId> {
+    let mut packages = FxHashSet::default();
+    packages.insert(package_id);
+    packages.extend(reachable.into_iter().map(|item_id| item_id.package));
+    packages
+}
+
+/// Convenience wrapper around [`collect_reachable_from_entry`] and
+/// [`collect_reachable_package_closure`].
+///
+/// # Panics
+///
+/// Panics if the package has no entry expression.
+#[must_use]
+pub fn collect_reachable_package_closure_from_entry(
+    store: &PackageStore,
+    package_id: PackageId,
+) -> FxHashSet<PackageId> {
+    let reachable = collect_reachable_from_entry(store, package_id);
+    collect_reachable_package_closure(package_id, &reachable)
+}
+
+/// Walks the bodies of a callable implementation, enqueueing every referenced
+/// item onto `worklist`. Closures enqueue `(pkg_id, local_item_id)` because
+/// `ExprKind::Closure` always resolves within the containing package.
+fn walk_callable_impl(
+    store: &PackageStore,
+    pkg_id: PackageId,
+    callable_impl: &CallableImpl,
+    worklist: &mut Vec<StoreItemId>,
+) {
+    let pkg = store.get(pkg_id);
+    crate::walk_utils::for_each_expr_in_callable_impl(pkg, callable_impl, &mut |_eid, expr| {
+        match &expr.kind {
+            ExprKind::Var(Res::Item(item_id), _) => {
+                worklist.push(StoreItemId::from((item_id.package, item_id.item)));
+            }
+            ExprKind::Closure(_, local_item_id) => {
+                worklist.push(StoreItemId::from((pkg_id, *local_item_id)));
+            }
+            _ => {}
+        }
+    });
+}
+
+/// Walks the expression subtree rooted at `expr_id`, enqueueing every
+/// referenced item onto `worklist`. Mirrors the closure scoping rule in
+/// [`walk_callable_impl`].
+fn walk_expr(
+    store: &PackageStore,
+    pkg_id: PackageId,
+    expr_id: qsc_fir::fir::ExprId,
+    worklist: &mut Vec<StoreItemId>,
+) {
+    let pkg = store.get(pkg_id);
+    crate::walk_utils::for_each_expr(pkg, expr_id, &mut |_eid, expr| match &expr.kind {
+        ExprKind::Var(Res::Item(item_id), _) => {
+            worklist.push(StoreItemId::from((item_id.package, item_id.item)));
+        }
+        ExprKind::Closure(_, local_item_id) => {
+            worklist.push(StoreItemId::from((pkg_id, *local_item_id)));
+        }
+        _ => {}
+    });
+}
diff --git a/source/compiler/qsc_fir_transforms/src/reachability/tests.rs b/source/compiler/qsc_fir_transforms/src/reachability/tests.rs
new file mode 100644
index 0000000000..852d917288
--- /dev/null
+++ b/source/compiler/qsc_fir_transforms/src/reachability/tests.rs
@@ -0,0 +1,391 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+use super::*;
+use expect_test::{Expect, expect};
+use indoc::indoc;
+use qsc_fir::fir::PackageLookup;
+
+/// Compiles Q# source, runs reachability analysis, and returns a sorted
+/// list of reachable callable names from the user package.
+fn extract_reachable(source: &str) -> String {
+    let (store, pkg_id) = crate::test_utils::compile_to_fir(source);
+    let reachable = collect_reachable_from_entry(&store, pkg_id);
+    let package = store.get(pkg_id);
+    let mut names: Vec<String> = Vec::new();
+    for store_id in &reachable {
+        if store_id.package != pkg_id {
+            continue;
+        }
+        let item = package.get_item(store_id.item);
+        if let ItemKind::Callable(decl) = &item.kind {
+            names.push(decl.name.name.to_string());
+        }
+    }
+    names.sort();
+    names.join("\n")
+}
+
+fn check(source: &str, expect: &Expect) {
+    expect.assert_eq(&extract_reachable(source));
+}
+
+#[test]
+fn reachable_includes_direct_call_chain() {
+    // Main calls A, A calls B. C is never called.
+    check(
+        indoc!
{" + namespace Test { + function B() : Unit {} + function A() : Unit { B(); } + function C() : Unit {} + @EntryPoint() + function Main() : Unit { A(); } + } + "}, + &expect![[r#" + A + B + Main"#]], + ); +} + +#[test] +fn unreachable_callable_excluded() { + // Only Main is called; Orphan is unreachable. + check( + indoc! {" + namespace Test { + function Orphan() : Unit {} + @EntryPoint() + function Main() : Unit {} + } + "}, + &expect![[r#" + Main"#]], + ); +} + +#[test] +fn transitive_chain_all_reachable() { + // Main → A → B → C (full chain). + check( + indoc! {" + namespace Test { + function C() : Unit {} + function B() : Unit { C(); } + function A() : Unit { B(); } + @EntryPoint() + function Main() : Unit { A(); } + } + "}, + &expect![[r#" + A + B + C + Main"#]], + ); +} + +#[test] +fn diamond_call_graph() { + // Main → A and Main → B, both call Leaf. + check( + indoc! {" + namespace Test { + function Leaf() : Unit {} + function A() : Unit { Leaf(); } + function B() : Unit { Leaf(); } + @EntryPoint() + function Main() : Unit { A(); B(); } + } + "}, + &expect![[r#" + A + B + Leaf + Main"#]], + ); +} + +#[test] +fn multiple_unreachable_functions() { + check( + indoc! {" + namespace Test { + function Dead1() : Unit {} + function Dead2() : Unit {} + function Alive() : Unit {} + @EntryPoint() + function Main() : Unit { Alive(); } + } + "}, + &expect![[r#" + Alive + Main"#]], + ); +} + +#[test] +fn entry_expression_followed() { + // A single entry point with no calls — only Main is reachable. + check( + indoc! {" + namespace Test { + function Helper() : Unit {} + @EntryPoint() + function Main() : Int { 42 } + } + "}, + &expect![[r#" + Main"#]], + ); +} + +#[test] +fn closure_inside_reachable_callable_followed() { + // A closure defined inside a reachable callable — the callable + // that the closure targets should also be reachable. + check( + indoc! 
{" + namespace Test { + @EntryPoint() + function Main() : Int { + let f = (x) -> x + 1; + f(5) + } + } + "}, + &expect![[r#" + + Main"#]], + ); +} + +#[test] +fn recursive_callable_reachable() { + // Recursive callable: Recurse calls itself. + check( + indoc! {" + namespace Test { + function Recurse(n : Int) : Int { + if n <= 0 { 0 } else { Recurse(n - 1) } + } + @EntryPoint() + function Main() : Int { Recurse(5) } + } + "}, + &expect![[r#" + Main + Recurse"#]], + ); +} + +#[test] +fn mutually_recursive_callables_reachable() { + // Mutual recursion: Ping calls Pong, Pong calls Ping. + check( + indoc! {" + namespace Test { + function Ping(n : Int) : Int { + if n <= 0 { 0 } else { Pong(n - 1) } + } + function Pong(n : Int) : Int { Ping(n) } + @EntryPoint() + function Main() : Int { Ping(3) } + } + "}, + &expect![[r#" + Main + Ping + Pong"#]], + ); +} + +#[test] +fn callable_only_in_unreachable_branch() { + // A call inside a conditional branch that is syntactically present + // but the function is still reachable because we do static analysis. + check( + indoc! {" + namespace Test { + function DeadEnd() : Unit {} + @EntryPoint() + function Main() : Unit { + if false { DeadEnd(); } + } + } + "}, + &expect![[r#" + DeadEnd + Main"#]], + ); +} + +#[test] +fn lambda_in_entry_expression() { + // Lambda defined and invoked directly in the entry expression. + check( + indoc! {" + namespace Test { + @EntryPoint() + function Main() : Int { + let add = (a, b) -> a + b; + add(3, 4) + } + } + "}, + &expect![[r#" + + Main"#]], + ); +} + +#[test] +fn cross_package_call_reachability_scoped_to_package() { + // Calling a stdlib function from the user package. The reachable set + // for the user package should include Main but should not include + // any stdlib callable (reachability returns StoreItemIds across + // packages, but our helper `extract_reachable` filters to user-package + // callables only). + check( + indoc! 
{" + namespace Test { + @EntryPoint() + function Main() : Int { + Microsoft.Quantum.Math.MaxI(1, 2) + } + } + "}, + &expect![[r#" + Main"#]], + ); +} + +#[test] +fn simulatable_intrinsic_callable_reachable() { + // An operation with @SimulatableIntrinsic() should appear in the + // reachable set when called from an entry point. + check( + indoc! {" + namespace Test { + @SimulatableIntrinsic() + operation MyOp() : Unit { + body intrinsic; + } + @EntryPoint() + operation Main() : Unit { + MyOp(); + } + } + "}, + &expect![[r#" + Main + MyOp"#]], + ); +} + +#[test] +fn dangling_item_reference_is_ignored() { + let (mut store, pkg_id) = crate::test_utils::compile_to_fir(indoc! {" + namespace Test { + function Helper() : Unit {} + @EntryPoint() + function Main() : Unit { + Helper(); + } + } + "}); + + let package = store.get(pkg_id); + let main_id = package + .items + .values() + .find_map(|item| match &item.kind { + ItemKind::Callable(decl) if decl.name.name.as_ref() == "Main" => Some(item.id), + _ => None, + }) + .expect("Main should exist"); + let helper_id = package + .items + .values() + .find_map(|item| match &item.kind { + ItemKind::Callable(decl) if decl.name.name.as_ref() == "Helper" => Some(item.id), + _ => None, + }) + .expect("Helper should exist"); + + store.get_mut(pkg_id).items.remove(helper_id); + + let reachable = collect_reachable_from_entry(&store, pkg_id); + assert!(reachable.contains(&StoreItemId::from((pkg_id, main_id)))); + assert!(!reachable.contains(&StoreItemId::from((pkg_id, helper_id)))); +} + +#[test] +fn seeds_include_transitive_deps_unreachable_from_entry() { + let (store, pkg_id) = crate::test_utils::compile_to_fir(indoc! 
{" + namespace Test { + function Helper() : Unit {} + function Unreachable() : Unit { Helper(); } + @EntryPoint() + function Main() : Unit {} + } + "}); + + let package = store.get(pkg_id); + + let find_callable = |name: &str| -> StoreItemId { + let local_id = package + .items + .values() + .find_map(|item| match &item.kind { + ItemKind::Callable(decl) if decl.name.name.as_ref() == name => Some(item.id), + _ => None, + }) + .unwrap_or_else(|| panic!("{name} should exist")); + StoreItemId::from((pkg_id, local_id)) + }; + + let unreachable_id = find_callable("Unreachable"); + let helper_id = find_callable("Helper"); + + // Baseline: neither Unreachable nor Helper is reachable from entry. + let entry_only = collect_reachable_from_entry(&store, pkg_id); + assert!( + !entry_only.contains(&unreachable_id), + "Unreachable should not be in the entry-only set" + ); + assert!( + !entry_only.contains(&helper_id), + "Helper should not be in the entry-only set" + ); + + // With Unreachable as a seed, both it and its transitive dep Helper + // should appear. + let seeded = collect_reachable_with_seeds(&store, pkg_id, &[unreachable_id]); + assert!( + seeded.contains(&unreachable_id), + "seed callable should be in the seeded set" + ); + assert!( + seeded.contains(&helper_id), + "transitive dep of seed should be in the seeded set" + ); +} + +#[test] +fn reachability_is_idempotent() { + let source = indoc! 
{" + namespace Test { + function Helper() : Unit {} + function Dead() : Unit {} + @EntryPoint() + function Main() : Unit { Helper(); } + } + "}; + let (store, pkg_id) = crate::test_utils::compile_to_fir(source); + let first = collect_reachable_from_entry(&store, pkg_id); + let second = collect_reachable_from_entry(&store, pkg_id); + assert_eq!(first, second, "reachability analysis should be idempotent"); +} diff --git a/source/compiler/qsc_fir_transforms/src/return_unify.rs b/source/compiler/qsc_fir_transforms/src/return_unify.rs new file mode 100644 index 0000000000..cab4ddc11d --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/return_unify.rs @@ -0,0 +1,3937 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Return unification pass. +//! +//! Eliminates all `ExprKind::Return` nodes from callable bodies, ensuring +//! every callable has exactly one exit point — the trailing expression of its +//! top-level block. +//! +//! Establishes [`crate::invariants::InvariantLevel::PostReturnUnify`] +//! (additionally) on top of [`crate::invariants::InvariantLevel::PostMono`]: +//! no `ExprKind::Return` remains in reachable code. +//! +//! # Pipeline position +//! +//! This pass runs after monomorphization (types are concrete) and before +//! defunctionalization. Synthesized expressions use `EMPTY_EXEC_RANGE`; the +//! [`crate::exec_graph_rebuild`] pass rebuilds correct exec graphs afterward. +//! See [`crate::run_pipeline_to_impl`] for the full ordering. +//! +//! # Architecture +//! +//! The pass uses a four-phase pipeline per callable block: +//! +//! 1. **Normalize** ([`normalize::hoist_returns_to_statement_boundary`]): +//! Hoist any `Return` in compound positions (e.g. inside a block-expression +//! used as a `Call` argument) to its enclosing statement boundary. After +//! this phase, every `Return` is either a bare `Semi(Return(_))` / +//! `Expr(Return(_))` or nested inside `If`, `While`, or `Block` statements. +//! +//! 2. 
**Dispatch** ([`should_use_flag_strategy`]): +//! Classify the block into one of the dispatch categories below and select +//! the appropriate transform strategy. +//! +//! 3. **Transform** ([`transform_block_if_else`] or +//! [`transform_block_with_flags`]): +//! Apply the selected strategy to eliminate all `Return` nodes. +//! +//! 4. **Simplify** ([`simplify_flag_patterns`]): +//! After the flag strategy, fold trivial identity patterns such as +//! `if __has_returned { v } else { v }` → `v`. This is the structured-IR +//! analog of LLVM's `SimplifyCFG` after `mergereturn`. +//! +//! ## Strategies +//! +//! 1. **If-else lifting** (primary, [`transform_block_if_else`]): +//! Restructures blocks containing returns into nested if-else expressions. +//! Handles guard clauses and branching returns without introducing mutable +//! state. Selected for category-A shapes. +//! +//! 2. **Flag-based transform** (fallback, [`transform_block_with_flags`]): +//! Introduces `__has_returned` and `__ret_val` mutable locals to handle +//! returns inside while loops, leaky nested-if patterns, and block-init +//! returns. Selected for category-B, -C, and -D shapes. +//! +//! # Input patterns +//! +//! - `Return(value)` appearing inside conditional or loop blocks. +//! +//! # Rewrites +//! +//! Flag-based rewrite of a return inside a while loop: +//! +//! ```text +//! // Before +//! mutable r = 0; +//! while cond { +//! if done { return r; } +//! r += 1; +//! } +//! +//! // After +//! mutable __has_returned = false; +//! mutable __ret_val = 0; +//! mutable r = 0; +//! while not __has_returned and cond { +//! if done { +//! __ret_val = r; +//! __has_returned = true; +//! } else { +//! r += 1; +//! } +//! } +//! if __has_returned { __ret_val } else { () } +//! ``` +//! +//! # Dispatch policy +//! +//! The function [`should_use_flag_strategy`] is the single dispatch point +//! that decides, per callable block, whether to use the structured if-else +//! 
lifting strategy or fall back to the flag-based transform. +//! +//! Fallback detection is driven by [`contains_return_in_while`] and +//! [`contains_leaky_early_return`], with nested statement/expression scans +//! delegated to [`crate::return_unify::detect::contains_return_in_expr`] and +//! its block-level companion. +//! +//! The dispatch categories recognized today are: +//! +//! * **Category A — guard clauses and pure `if`/`else` nests.** Returns +//! appear only on conditional branches outside `while` bodies and do not +//! hit the leaky nested-if shape; the structured strategy lifts them into +//! nested if-else expressions. +//! * **Category B — returns inside while loops.** Any `Return` reachable in a +//! `while` body causes [`should_use_flag_strategy`] to select the flag-based +//! fallback. +//! * **Category C — leaky nested-if early returns.** A `Return` under an +//! if-without-else chain at depth >= 2 causes +//! [`should_use_flag_strategy`] to select the flag-based fallback. +//! * **Category D — returns inside block-expression Local initializers.** +//! A `Return` inside a `Block` expression used as a `Local` initializer +//! causes [`should_use_flag_strategy`] to select the flag-based fallback +//! (detected by [`contains_return_in_block_init`]). +//! Non-block `Local`-initializer `MayReturn` shapes stay on the structured +//! path and are rewritten by [`transform_local_init`] via +//! [`decompose_returning_init`]. +//! +//! The flag-strategy fallback is modeled on the LLVM lowering pattern for +//! early returns (cf. LLVM `UnifyFunctionExitNodes` / `mergereturn`): a +//! synthesized `__has_returned` slot plus a merge block guard the remainder +//! of the loop body, preserving the semantics of the original early exit +//! when category-B, -C, or -D shape makes the structured lowering unsound. +//! +//! # Invariant contracts +//! +//! After this pass completes, the following invariants hold: +//! +//! * **No `Return` nodes** — checked by +//! 
`crate::invariants::check_no_returns`. Every `ExprKind::Return` in +//! reachable code must be eliminated; any surviving `Return` triggers a +//! hard assertion failure. +//! * **Non-Unit block tails** — checked by +//! `crate::invariants::check_non_unit_block_tails`. Every block whose +//! type is not `Unit` must end with a `StmtKind::Expr` (not `Semi`), +//! ensuring downstream code generation sees a value-producing tail. +//! +//! These invariants are verified at +//! [`crate::invariants::InvariantLevel::PostReturnUnify`] by the pipeline +//! runner after this pass returns. +//! +//! # Error reporting +//! +//! [`unify_returns`] returns `Vec` rather than panicking. The known +//! user-reachable error is [`Error::UnsupportedLoopReturnType`]: the flag +//! strategy requires a classical default for `__ret_val`, but types like +//! `Qubit` have no classical default. This is caught by +//! [`can_create_classical_default`] before entering the transform, producing +//! a user-facing diagnostic. Processing continues for remaining callables. +//! +//! # Qubit release interaction +//! +//! This pass does not classify or hoist release calls. Release operations are +//! treated as ordinary side effects, and correctness comes from preserving +//! control-flow reachability while eliminating `Return` nodes: +//! +//! - In structured `if` rewrites, continuation statements (including releases) +//! are moved only to fallthrough paths. +//! - For `if` where both branches always return, trailing continuation is dead +//! and removed. +//! - For `Local` initializer `MayReturn` shapes, the pass decomposes the init +//! into an outer guard statement ([`decompose_returning_init`]) so the +//! continuation executes only on fallthrough. +//! +//! This avoids dedicated release-shape analysis while still preventing +//! path-duplication bugs such as double release. +//! +//! Extension: to add a new category, widen [`should_use_flag_strategy`] +//! 
and extend the test matrix under `return_unify/normalize/tests.rs`. + +mod detect; +mod normalize; + +#[cfg(test)] +mod tests; + +#[cfg(all(test, feature = "slow-proptest-tests"))] +mod semantic_equivalence_tests; + +use crate::fir_builder::{ + alloc_assign_expr, alloc_bin_op_expr, alloc_block, alloc_block_expr, alloc_bool_lit, + alloc_expr_stmt, alloc_if_expr, alloc_local_var, alloc_local_var_expr, alloc_not_expr, + alloc_semi_stmt, alloc_unit_expr, functored_specs, +}; +use miette::Diagnostic; +use num_bigint::BigInt; +use qsc_data_structures::span::Span; +use qsc_fir::{ + assigner::Assigner, + fir::{ + BinOp, BlockId, CallableDecl, CallableImpl, Expr, ExprId, ExprKind, Ident, ItemKind, Lit, + LocalItemId, LocalVarId, Mutability, Package, PackageId, PackageLookup, PackageStore, Pat, + PatId, PatKind, Res, Result, StmtId, StmtKind, StoreItemId, UnOp, + }, + ty::{Prim, Ty}, +}; +use rustc_hash::{FxHashMap, FxHashSet}; +use std::rc::Rc; +use thiserror::Error; + +use crate::{EMPTY_EXEC_RANGE, reachability::collect_reachable_from_entry}; + +/// Errors that can occur during return unification. +#[derive(Clone, Debug, Diagnostic, Error)] +pub enum Error { + /// The flag-based return unification strategy requires a classical default + /// value for the return type to initialize `__ret_val`. Types such as + /// `Qubit` have no classical default and cannot be handled. + #[error("cannot unify returns of type `{0}` inside a loop")] + #[diagnostic(code("Qsc.ReturnUnify.UnsupportedLoopReturnType"))] + #[diagnostic(help( + "the return type has no classical default value; \ + consider restructuring to avoid returning this type from inside a loop" + ))] + UnsupportedLoopReturnType( + String, + #[label("callable with unsupported return pattern")] Span, + ), +} + +type UdtPureTyCache = FxHashMap<(PackageId, LocalItemId), Ty>; + +/// Recursively collects UDT item references from a type. 
+///
+/// Walks nested tuples, arrays, and arrows to find all `Ty::Udt` variants and
+/// records their `(PackageId, LocalItemId)` identity in `refs`.
+fn collect_udt_refs_from_ty(ty: &Ty, refs: &mut FxHashSet<(PackageId, LocalItemId)>) {
+    match ty {
+        Ty::Udt(Res::Item(item_id)) => {
+            refs.insert((item_id.package, item_id.item));
+        }
+        Ty::Array(inner) => collect_udt_refs_from_ty(inner, refs),
+        Ty::Tuple(tys) => {
+            for t in tys {
+                collect_udt_refs_from_ty(t, refs);
+            }
+        }
+        Ty::Arrow(arrow) => {
+            collect_udt_refs_from_ty(&arrow.input, refs);
+            collect_udt_refs_from_ty(&arrow.output, refs);
+        }
+        _ => {}
+    }
+}
+
+/// Builds a UDT pure-type cache scoped to UDTs referenced in reachable callable return types.
+///
+/// Only resolves `get_pure_ty()` for UDTs that appear in the output types of callables in
+/// `reachable`. This avoids scanning all packages × all items when only a fraction of UDTs
+/// are actually needed during return unification.
+fn build_scoped_udt_pure_ty_cache(
+    store: &PackageStore,
+    reachable: &FxHashSet<StoreItemId>,
+) -> UdtPureTyCache {
+    let mut needed_udts: FxHashSet<(PackageId, LocalItemId)> = FxHashSet::default();
+    for item_id in reachable {
+        let pkg = store.get(item_id.package);
+        let item = pkg.get_item(item_id.item);
+        if let ItemKind::Callable(decl) = &item.kind {
+            collect_udt_refs_from_ty(&decl.output, &mut needed_udts);
+        }
+    }
+    let mut cache = FxHashMap::default();
+    for (pkg_id, local_id) in &needed_udts {
+        let pkg = store.get(*pkg_id);
+        let item = pkg.get_item(*local_id);
+        if let ItemKind::Ty(_, udt) = &item.kind {
+            cache.insert((*pkg_id, *local_id), udt.get_pure_ty());
+        }
+    }
+    cache
+}
+
+/// Eliminate all `ExprKind::Return` nodes from reachable callable bodies.
+///
+/// # Before
+/// ```text
+/// callable body { ...; return v; ...; trailing }
+/// ```
+/// # After
+/// ```text
+/// callable body { ...; ...; new_trailing } // no ExprKind::Return remains
+/// ```
+/// # Requires
+/// - `package_id` is present in `store`.
+/// - Monomorphization has run (types are concrete). +/// +/// # Ensures +/// - Establishes [`crate::invariants::InvariantLevel::PostReturnUnify`] on +/// top of `PostMono`: no `ExprKind::Return` in reachable bodies. +/// - Each rewritten body's trailing expression produces the callable's +/// return value via if-else lifting or the flag-based transform. +/// +/// # Mutations +/// - Rewrites `CallableDecl` body blocks in `store[package_id]`. +/// - Allocates new FIR nodes through `assigner`. +// +// Only entry-reachable callables are unified. Unreachable callables retain +// their `Return` nodes, but this is safe because: +// 1. `check_no_returns` walks the same reachable set returned by +// [`collect_reachable_from_entry`]. +// 2. Downstream passes (defunc, udt_erase, sroa, arg_promote, +// exec_graph_rebuild) recompute reachability via the same walker and +// never re-reach a callable that was unreachable here. Defunc's +// specialization creates new clone items rather than widening +// reachability to existing-but-dead items. +// 3. A future pass that violates this (for example, inlines a dead call or +// rewires a dead callable into the call graph) must re-invoke +// `unify_returns` on newly reachable items before `check_no_returns` +// runs. +// +// Re-audit trigger: the defunc "tagged-union" future work noted at +// source/compiler/qsc_fir_transforms/src/defunctionalize.rs:42-45 could +// change the reachability story above; this rationale must be re-validated +// if that design lands. Assessment (2026-04): tagged-union +// defunctionalization would create *new* dispatch items (union type + +// apply function) rather than widening reachability to existing dead +// callables, so the invariant is expected to hold. Re-audit if the +// tagged-union design instead reuses or inlines dead callables. 
+pub fn unify_returns(
+    store: &mut PackageStore,
+    package_id: PackageId,
+    assigner: &mut Assigner,
+) -> Vec<Error> {
+    let reachable = collect_reachable_from_entry(store, package_id);
+    let udt_pure_tys = build_scoped_udt_pure_ty_cache(store, &reachable);
+    let mut errors = Vec::new();
+
+    let local_reachable: Vec<_> = reachable
+        .iter()
+        .filter(|id| id.package == package_id)
+        .map(|id| id.item)
+        .collect();
+
+    let package = store.get_mut(package_id);
+
+    for item_id in local_reachable {
+        let callable = {
+            let item = package.get_item(item_id);
+            match &item.kind {
+                ItemKind::Callable(callable) => callable.clone(),
+                _ => continue,
+            }
+        };
+        let return_ty = callable.output.clone();
+        for block_id in get_callable_body_blocks(&callable) {
+            if contains_return_in_block(package, block_id) {
+                // Pre-pass: hoist any compound-position Return to its
+                // enclosing statement boundary so the strategy pass only sees
+                // bare returns or returns inside statement-carrying Block/If/While.
+                normalize::hoist_returns_to_statement_boundary(
+                    package, assigner, package_id, block_id,
+                );
+                if should_use_flag_strategy(package, block_id) {
+                    // The flag strategy requires a classical default for
+                    // `__ret_val`. Check before entering the transform so
+                    // unsupported types (e.g. Qubit) produce a user-facing
+                    // diagnostic instead of panicking.
+                    if !can_create_classical_default(&return_ty, &udt_pure_tys) {
+                        errors.push(Error::UnsupportedLoopReturnType(
+                            format!("{return_ty}"),
+                            callable.name.span,
+                        ));
+                        continue;
+                    }
+                    transform_block_with_flags(
+                        package,
+                        assigner,
+                        package_id,
+                        block_id,
+                        &return_ty,
+                        &udt_pure_tys,
+                    );
+                    simplify_flag_patterns(package, block_id);
+                } else {
+                    transform_block_if_else(package, assigner, block_id, &return_ty);
+                }
+            }
+        }
+    }
+
+    errors
+}
+
+/// Extract every explicit body block from a callable declaration.
+///
+/// # Before
+/// ```text
+/// CallableDecl { implementation: Spec { body, adj?, ctl?, ctl_adj?
} }
+/// ```
+/// # After
+/// ```text
+/// [body.block, adj.block?, ctl.block?, ctl_adj.block?]
+/// ```
+/// # Requires
+/// - `callable` has been lowered to FIR.
+///
+/// # Ensures
+/// - Returns an empty `Vec` for `CallableImpl::Intrinsic`.
+/// - Includes only specializations with an explicit body block.
+///
+/// # Mutations
+/// - None (read-only).
+fn get_callable_body_blocks(callable: &CallableDecl) -> Vec<BlockId> {
+    // Exhaustive match over CallableImpl. Adding a variant fails to compile
+    // here; extend the match rather than adding a wildcard.
+    match &callable.implementation {
+        CallableImpl::Intrinsic => Vec::new(),
+        CallableImpl::Spec(spec_impl) => {
+            let mut blocks = vec![spec_impl.body.block];
+            for spec in functored_specs(spec_impl) {
+                blocks.push(spec.block);
+            }
+            blocks
+        }
+        CallableImpl::SimulatableIntrinsic(spec) => vec![spec.block],
+    }
+}
+
+use detect::{contains_return_in_block, contains_return_in_expr, contains_return_in_stmt};
+
+/// Returns true if any `Return` node is inside a while loop body.
+fn contains_return_in_while(package: &Package, block_id: BlockId) -> bool {
+    let block = package.get_block(block_id);
+    block
+        .stmts
+        .iter()
+        .any(|&stmt_id| contains_return_in_while_stmt(package, stmt_id))
+}
+
+/// Returns true if any `Return` node is reachable through a while-condition
+/// expression (at any nesting depth).
+fn contains_return_in_while_condition(package: &Package, block_id: BlockId) -> bool { + let block = package.get_block(block_id); + block + .stmts + .iter() + .any(|&stmt_id| contains_return_in_while_condition_stmt(package, stmt_id)) +} + +fn contains_return_in_while_stmt(package: &Package, stmt_id: StmtId) -> bool { + let stmt = package.get_stmt(stmt_id); + match &stmt.kind { + StmtKind::Expr(expr_id) | StmtKind::Semi(expr_id) => { + contains_return_in_while_expr(package, *expr_id) + } + _ => false, + } +} + +fn contains_return_in_while_condition_stmt(package: &Package, stmt_id: StmtId) -> bool { + let stmt = package.get_stmt(stmt_id); + match &stmt.kind { + StmtKind::Expr(expr_id) | StmtKind::Semi(expr_id) | StmtKind::Local(_, _, expr_id) => { + contains_return_in_while_condition_expr(package, *expr_id) + } + StmtKind::Item(_) => false, + } +} + +fn contains_return_in_while_expr(package: &Package, expr_id: ExprId) -> bool { + let expr = package.get_expr(expr_id); + match &expr.kind { + ExprKind::While(_, body_id) => contains_return_in_block(package, *body_id), + ExprKind::Block(block_id) => contains_return_in_while(package, *block_id), + ExprKind::If(_, then_id, else_opt) => { + contains_return_in_while_expr(package, *then_id) + || else_opt.is_some_and(|e| contains_return_in_while_expr(package, e)) + } + _ => false, + } +} + +fn contains_return_in_while_condition_expr(package: &Package, expr_id: ExprId) -> bool { + let expr = package.get_expr(expr_id); + match &expr.kind { + ExprKind::While(cond_id, body_id) => { + contains_return_in_expr(package, *cond_id) + || contains_return_in_while_condition(package, *body_id) + } + ExprKind::Block(block_id) => contains_return_in_while_condition(package, *block_id), + ExprKind::If(cond_id, then_id, else_opt) => { + contains_return_in_while_condition_expr(package, *cond_id) + || contains_return_in_while_condition_expr(package, *then_id) + || else_opt.is_some_and(|e| contains_return_in_while_condition_expr(package, e)) + } + _ 
=> false, + } +} + +/// Returns true when the flag-based strategy is required for `block_id`. +/// +/// The if-else lifting strategy cannot correctly handle either: +/// 1. Returns nested inside while loops (already detected by +/// [`contains_return_in_while`]), or +/// 2. Returns reachable through while-condition expressions (detected by +/// [`contains_return_in_while_condition`]), or +/// 3. "Leaky" early returns inside an if-without-else nested at depth >= 2, +/// where lifting would synthesize an empty-else continuation whose type +/// does not match the non-Unit return type (detected by +/// [`contains_leaky_early_return`]). +/// 4. Returns inside a `Block` expression used as a `Local` initializer, +/// where the structured strategy's `strip_returns_from_block` would +/// consume the return at the block level instead of propagating it to +/// the enclosing callable (detected by +/// [`contains_return_in_block_init`]). +fn should_use_flag_strategy(package: &Package, block_id: BlockId) -> bool { + contains_return_in_while(package, block_id) + || contains_return_in_while_condition(package, block_id) + || contains_leaky_early_return(package, block_id) + || contains_return_in_block_init(package, block_id) +} + +/// Returns true when any `Local` statement in the block has an initializer +/// that is a `Block` expression containing a `Return` inside a nested +/// control-flow construct (`If`, `While`, or inner `Block`), or an `If` +/// expression containing a `Return` inside a `While` loop. Bare returns +/// at the init block's direct statement level are handled correctly by the +/// structured strategy's `transform_local_init` + `apply_bare_return`, so +/// they are excluded. For `If`-expression initializers, +/// `strip_returns_from_expr` handles returns directly in branches and +/// nested if-else chains; only returns inside `While` loops within the +/// branches require the flag strategy. 
+fn contains_return_in_block_init(package: &Package, block_id: BlockId) -> bool { + let block = package.get_block(block_id); + block.stmts.iter().any(|&stmt_id| { + let stmt = package.get_stmt(stmt_id); + if let StmtKind::Local(_, _, init_id) = &stmt.kind { + let init_expr = package.get_expr(*init_id); + match &init_expr.kind { + ExprKind::Block(inner_block_id) => { + return block_has_nested_return(package, *inner_block_id); + } + ExprKind::If(_, then_id, else_opt) => { + return contains_return_in_while_expr(package, *then_id) + || else_opt.is_some_and(|e| contains_return_in_while_expr(package, e)); + } + _ => {} + } + } + false + }) +} + +/// Returns true when any statement in the block contains a `Return` inside +/// a nested construct (`If`, `While`, inner `Block`) rather than as a bare +/// `Semi(Return(_))` / `Expr(Return(_))`. +fn block_has_nested_return(package: &Package, block_id: BlockId) -> bool { + let block = package.get_block(block_id); + block.stmts.iter().any(|&stmt_id| { + let stmt = package.get_stmt(stmt_id); + let expr_id = match &stmt.kind { + StmtKind::Expr(e) | StmtKind::Semi(e) | StmtKind::Local(_, _, e) => *e, + StmtKind::Item(_) => return false, + }; + let expr = package.get_expr(expr_id); + match &expr.kind { + ExprKind::If(_, then_id, else_opt) => { + contains_return_in_expr(package, *then_id) + || else_opt.is_some_and(|e| contains_return_in_expr(package, e)) + } + ExprKind::While(cond, body) => { + contains_return_in_expr(package, *cond) || contains_return_in_block(package, *body) + } + ExprKind::Block(bid) => contains_return_in_block(package, *bid), + _ => false, + } + }) +} + +/// Returns true if a `Return` appears inside an if-without-else nested at +/// depth >= 2 within the block. Such returns cannot be lifted via the +/// if-else strategy because the synthesized empty-else continuation block +/// would be typed `Unit`, conflicting with a non-Unit callable return type. 
+fn contains_leaky_early_return(package: &Package, block_id: BlockId) -> bool { + let block = package.get_block(block_id); + block + .stmts + .iter() + .any(|&stmt_id| leaky_early_return_in_stmt(package, stmt_id, 0)) +} + +fn leaky_early_return_in_stmt(package: &Package, stmt_id: StmtId, if_no_else_depth: u32) -> bool { + let stmt = package.get_stmt(stmt_id); + match &stmt.kind { + StmtKind::Expr(expr_id) | StmtKind::Semi(expr_id) | StmtKind::Local(_, _, expr_id) => { + leaky_early_return_in_expr(package, *expr_id, if_no_else_depth) + } + StmtKind::Item(_) => false, + } +} + +fn leaky_early_return_in_expr(package: &Package, expr_id: ExprId, if_no_else_depth: u32) -> bool { + let expr = package.get_expr(expr_id); + match &expr.kind { + ExprKind::Return(_) => if_no_else_depth >= 2, + ExprKind::If(_, then_id, None) => { + leaky_early_return_in_expr(package, *then_id, if_no_else_depth + 1) + } + ExprKind::If(_, then_id, Some(else_id)) => { + leaky_early_return_in_expr(package, *then_id, if_no_else_depth) + || leaky_early_return_in_expr(package, *else_id, if_no_else_depth) + } + ExprKind::Block(bid) => { + let inner = package.get_block(*bid); + inner + .stmts + .iter() + .any(|&s| leaky_early_return_in_stmt(package, s, if_no_else_depth)) + } + // A while whose body transitively contains a return is already + // covered by `contains_return_in_while`; re-report here so any such + // shape triggers the flag strategy through this predicate too. A + // return-free while (e.g., residual structural while after hoist + // moves a return out of a Local initializer) must stay on the + // structured path, otherwise the flag strategy would preserve the + // now-dead while and its references to deleted bindings. 
+ ExprKind::While(_, body) => contains_return_in_block(package, *body), + _ => false, + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum ReturnFlow { + FallsThrough, + MayReturn, + AlwaysReturns, +} + +impl ReturnFlow { + fn sequence_with(self, next: Self) -> Self { + match (self, next) { + (Self::FallsThrough, flow) => flow, + (Self::AlwaysReturns, _) | (Self::MayReturn, Self::AlwaysReturns) => { + Self::AlwaysReturns + } + (Self::MayReturn, _) => Self::MayReturn, + } + } + + fn from_if_branches(then_flow: Self, else_flow: Option) -> Self { + match else_flow { + Some(else_flow) => match (then_flow, else_flow) { + (Self::AlwaysReturns, Self::AlwaysReturns) => Self::AlwaysReturns, + (Self::FallsThrough, Self::FallsThrough) => Self::FallsThrough, + _ => Self::MayReturn, + }, + None if then_flow == Self::FallsThrough => Self::FallsThrough, + None => Self::MayReturn, + } + } +} + +fn block_return_flow(package: &Package, block_id: BlockId) -> ReturnFlow { + let mut flow = ReturnFlow::FallsThrough; + for &stmt_id in &package.get_block(block_id).stmts { + flow = flow.sequence_with(stmt_return_flow(package, stmt_id)); + if flow == ReturnFlow::AlwaysReturns { + return flow; + } + } + flow +} + +fn stmt_return_flow(package: &Package, stmt_id: StmtId) -> ReturnFlow { + let stmt = package.get_stmt(stmt_id); + match &stmt.kind { + StmtKind::Expr(expr_id) | StmtKind::Semi(expr_id) | StmtKind::Local(_, _, expr_id) => { + expr_return_flow(package, *expr_id) + } + StmtKind::Item(_) => ReturnFlow::FallsThrough, + } +} + +fn sequence_expr_flows( + package: &Package, + expr_ids: impl IntoIterator, +) -> ReturnFlow { + let mut flow = ReturnFlow::FallsThrough; + for expr_id in expr_ids { + flow = flow.sequence_with(expr_return_flow(package, expr_id)); + if flow == ReturnFlow::AlwaysReturns { + return flow; + } + } + flow +} + +fn expr_return_flow(package: &Package, expr_id: ExprId) -> ReturnFlow { + let expr = package.get_expr(expr_id); + match &expr.kind { + 
ExprKind::Return(_) => ReturnFlow::AlwaysReturns, + ExprKind::Block(block_id) => block_return_flow(package, *block_id), + ExprKind::If(cond_id, then_id, else_opt) => { + let cond_flow = expr_return_flow(package, *cond_id); + let branch_flow = ReturnFlow::from_if_branches( + expr_return_flow(package, *then_id), + else_opt.map(|else_id| expr_return_flow(package, else_id)), + ); + cond_flow.sequence_with(branch_flow) + } + ExprKind::While(cond_id, body_id) => match expr_return_flow(package, *cond_id) { + ReturnFlow::AlwaysReturns => ReturnFlow::AlwaysReturns, + ReturnFlow::MayReturn => ReturnFlow::MayReturn, + ReturnFlow::FallsThrough => match block_return_flow(package, *body_id) { + ReturnFlow::FallsThrough => ReturnFlow::FallsThrough, + ReturnFlow::MayReturn | ReturnFlow::AlwaysReturns => ReturnFlow::MayReturn, + }, + }, + ExprKind::Array(exprs) | ExprKind::ArrayLit(exprs) | ExprKind::Tuple(exprs) => { + sequence_expr_flows(package, exprs.iter().copied()) + } + ExprKind::ArrayRepeat(a_id, b_id) + | ExprKind::Assign(a_id, b_id) + | ExprKind::AssignOp(_, a_id, b_id) + | ExprKind::BinOp(_, a_id, b_id) + | ExprKind::Call(a_id, b_id) + | ExprKind::Index(a_id, b_id) + | ExprKind::AssignField(a_id, _, b_id) + | ExprKind::UpdateField(a_id, _, b_id) => sequence_expr_flows(package, [*a_id, *b_id]), + ExprKind::AssignIndex(a_id, b_id, c_id) | ExprKind::UpdateIndex(a_id, b_id, c_id) => { + sequence_expr_flows(package, [*a_id, *b_id, *c_id]) + } + ExprKind::Fail(inner_id) | ExprKind::Field(inner_id, _) | ExprKind::UnOp(_, inner_id) => { + expr_return_flow(package, *inner_id) + } + ExprKind::Range(start, step, end) => { + let expr_ids = [start, step, end].into_iter().flatten().copied(); + sequence_expr_flows(package, expr_ids) + } + ExprKind::Struct(_, copy, fields) => { + let copy_flow = copy.map_or(ReturnFlow::FallsThrough, |copy_id| { + expr_return_flow(package, copy_id) + }); + let field_flow = sequence_expr_flows(package, fields.iter().map(|field| field.value)); + 
copy_flow.sequence_with(field_flow) + } + ExprKind::String(components) => { + let expr_ids = components.iter().filter_map(|component| match component { + qsc_fir::fir::StringComponent::Expr(expr_id) => Some(*expr_id), + qsc_fir::fir::StringComponent::Lit(_) => None, + }); + sequence_expr_flows(package, expr_ids) + } + ExprKind::Closure(_, _) | ExprKind::Hole | ExprKind::Lit(_) | ExprKind::Var(_, _) => { + ReturnFlow::FallsThrough + } + } +} + +/// Classification of a statement that contains a Return. +enum ReturnClass { + /// The statement is `Expr(Return(inner))` or `Semi(Return(inner))`. + BareReturn(ExprId), + /// An if where only the then-branch contains a Return. + IfThenReturn { + cond: ExprId, + then_expr: ExprId, + else_opt: Option, + }, + /// An if where both branches contain a Return. + IfBothReturn { + cond: ExprId, + then_expr: ExprId, + else_expr: ExprId, + }, + /// An if where only the else-branch contains a Return. Normalized to + /// `IfThenReturn` with negated condition at dispatch so downstream + /// transform code only handles the then-return shape. + IfElseReturn { + cond: ExprId, + then_expr: ExprId, + else_expr: ExprId, + }, + /// A nested block expression that needs recursive descent. + NestedBlock(BlockId), + /// A Local binding whose init expression contains Returns. + /// The returns must be stripped from the init and types updated. + LocalInit(PatId, ExprId), + /// No return found. + None, +} + +/// Classify a statement's relationship to `ExprKind::Return`. +/// +/// # Before +/// ```text +/// StmtKind::{Expr, Semi}(If/Block/Return/...) | StmtKind::Local(.., init) +/// ``` +/// # After +/// ```text +/// ReturnClass::{BareReturn, IfThenReturn, IfBothReturn, IfElseReturn, +/// NestedBlock, LocalInit, None} +/// ``` +/// # Requires +/// - `kind` is the kind of a statement whose expressions are valid in `package`. +/// +/// # Ensures +/// - Returns the most specific shape matching the statement's surface expression. 
/// - Returns `ReturnClass::None` for `StmtKind::Item` and return-free initializers.
///
/// # Mutations
/// - None (read-only).
///
/// # Notes
///
/// `StmtKind::Semi(e)` and `StmtKind::Expr(e)` both collapse to
/// `ReturnClass::BareReturn(*inner)` using the same `inner` expression. This
/// mapping is lossy on purpose: [`apply_bare_return`] discards the source
/// `StmtKind` and synthesizes a fresh `StmtKind::Expr(inner)` trailing
/// statement, then overwrites `block.ty` with the inner expression's type.
/// Downstream callers therefore must not depend on the original `Semi` vs
/// `Expr` kind being preserved for a bare-return statement.
fn classify_return_stmt(package: &Package, kind: &StmtKind) -> ReturnClass {
    let expr_id = match kind {
        StmtKind::Expr(id) | StmtKind::Semi(id) => *id,
        // A `let` binding: the interesting returns (if any) live in its
        // initializer expression, so classify on that directly.
        StmtKind::Local(_, pat_id, init_id) => {
            return if contains_return_in_expr(package, *init_id) {
                ReturnClass::LocalInit(*pat_id, *init_id)
            } else {
                ReturnClass::None
            };
        }
        StmtKind::Item(_) => return ReturnClass::None,
    };

    let expr = package.get_expr(expr_id);
    match &expr.kind {
        ExprKind::Return(inner) => ReturnClass::BareReturn(*inner),
        ExprKind::If(cond, then_expr, else_opt) => {
            // Decide the shape by which branches transitively contain a
            // Return; an absent else branch trivially contains none.
            let then_has = contains_return_in_expr(package, *then_expr);
            let else_has = else_opt.is_some_and(|e| contains_return_in_expr(package, e));
            match (then_has, else_has) {
                (true, true) => ReturnClass::IfBothReturn {
                    cond: *cond,
                    then_expr: *then_expr,
                    else_expr: else_opt.expect("else branch must exist when it contains a return"),
                },
                (true, false) => ReturnClass::IfThenReturn {
                    cond: *cond,
                    then_expr: *then_expr,
                    else_opt: *else_opt,
                },
                (false, true) => ReturnClass::IfElseReturn {
                    cond: *cond,
                    then_expr: *then_expr,
                    else_expr: else_opt.expect("else branch must exist when it contains a return"),
                },
                (false, false) => ReturnClass::None,
            }
        }
        ExprKind::Block(block_id) => {
            if contains_return_in_block(package, *block_id) {
                ReturnClass::NestedBlock(*block_id)
            } else {
                ReturnClass::None
            }
        }
        _ => ReturnClass::None,
    }
}

/// Rewrite the first return-containing statement in `block_id` into return-free flow.
///
/// Finds the first statement that still contains an `ExprKind::Return`,
/// classifies it via [`classify_return_stmt`], and dispatches to the
/// matching per-shape rewriter:
///
/// | Classification      | Dispatched helper                                 |
/// |---------------------|---------------------------------------------------|
/// | `BareReturn`        | [`transform_bare_return`]                         |
/// | `IfThenReturn`      | [`apply_if_then_return`]                          |
/// | `IfBothReturn`      | [`apply_if_both_return`]                          |
/// | `IfElseReturn`      | [`transform_if_else_return`] (normalizing rewrite)|
/// | `NestedBlock`       | [`transform_nested_block`]                        |
/// | `LocalInit`         | [`transform_local_init`]                          |
/// | `None`              | no-op                                             |
///
/// # Before
/// ```text
/// { stmts_before; if cond { return v; } stmts_after }   // IfThenReturn shape
/// ```
/// # After
/// ```text
/// { stmts_before; if cond { v } else { stmts_after } }
/// ```
/// # Requires
/// - `block_id` is valid in `package`.
/// - The normalization pre-pass has run, so Returns only appear at statement
///   boundaries or inside `Block`/`If`/`While` expressions.
///
/// # Ensures
/// - Returns `true` iff the block was rewritten.
/// - When `true`, the first return-containing statement has been replaced
///   by return-free control flow; recursion may rewrite nested blocks.
///
/// # Mutations
/// - Rewrites `Block.stmts` and `Block.ty` for `block_id` and reachable
///   sub-blocks via the dispatched helper.
/// - Allocates new FIR nodes through `assigner`.
+#[allow(clippy::too_many_lines)] +fn transform_block_if_else( + package: &mut Package, + assigner: &mut Assigner, + block_id: BlockId, + return_ty: &Ty, +) -> bool { + let stmts = package.get_block(block_id).stmts.clone(); + let first_return_idx = stmts + .iter() + .position(|&sid| contains_return_in_stmt(package, sid)); + let Some(idx) = first_return_idx else { + return false; + }; + + let stmt_kind = package.get_stmt(stmts[idx]).kind.clone(); + match classify_return_stmt(package, &stmt_kind) { + ReturnClass::BareReturn(inner_expr_id) => { + transform_bare_return(package, assigner, block_id, idx, inner_expr_id) + } + ReturnClass::IfThenReturn { + cond, + then_expr, + else_opt, + } => { + apply_if_then_return( + package, assigner, block_id, idx, cond, then_expr, else_opt, return_ty, + ); + true + } + ReturnClass::IfBothReturn { + cond, + then_expr, + else_expr, + } => { + apply_if_both_return( + package, assigner, block_id, idx, cond, then_expr, else_expr, return_ty, + ); + true + } + ReturnClass::IfElseReturn { + cond, + then_expr, + else_expr, + } => { + transform_if_else_return( + package, assigner, block_id, return_ty, idx, cond, then_expr, else_expr, + ); + true + } + ReturnClass::NestedBlock(inner_block_id) => transform_nested_block( + package, + assigner, + block_id, + return_ty, + &stmts, + idx, + inner_block_id, + ), + ReturnClass::LocalInit(pat_id, init_expr_id) => transform_local_init( + package, + assigner, + block_id, + return_ty, + idx, + pat_id, + init_expr_id, + ), + ReturnClass::None => false, + } +} + +/// Normalize an `IfElseReturn` (return only in the else branch) to the +/// `IfThenReturn` shape by negating the condition and swapping branches, +/// then delegate to [`apply_if_then_return`]. 
+/// +/// ```text +/// // Before +/// if cond { then_expr } else { return v; } +/// stmts_after; +/// +/// // After (equivalent to apply_if_then_return on the negated shape) +/// if not cond { v } else { then_expr; stmts_after; } +/// ``` +#[allow(clippy::too_many_arguments)] +fn transform_if_else_return( + package: &mut Package, + assigner: &mut Assigner, + block_id: BlockId, + return_ty: &Ty, + idx: usize, + cond: ExprId, + then_expr: ExprId, + else_expr: ExprId, +) { + // Convert to IfThenReturn by negating the condition and swapping branches. + let neg_cond = alloc_not_expr(package, assigner, cond, Span::default()); + apply_if_then_return( + package, + assigner, + block_id, + idx, + neg_cond, + else_expr, + Some(then_expr), + return_ty, + ); +} + +/// Rewrite a `let` binding whose initializer contains `Return` nodes. +/// +/// When the init always returns, the let and all continuation are dead code — +/// strip returns and replace the block with just the init expression. +/// +/// When the init may return (some paths return, others produce values), +/// decompose the init into a guard statement at the outer block level so the +/// return is visible to `transform_block_if_else`. This avoids stripping +/// returns in-place (which would leave side effects from the return path +/// reachable on the fallthrough path). +/// +/// ```text +/// // Before +/// { +/// let x = if cond { return v; } else { u }; +/// continuation; +/// } +/// +/// // After decomposition (MayReturn case) +/// { +/// if cond { return v; } // guard — return preserved +/// let x = u; // fallthrough value only +/// continuation; +/// } +/// // Then transform_block_if_else handles the guard normally. 
/// ```
fn transform_local_init(
    package: &mut Package,
    assigner: &mut Assigner,
    block_id: BlockId,
    return_ty: &Ty,
    idx: usize,
    pat_id: PatId,
    init_expr_id: ExprId,
) -> bool {
    let init_flow = expr_return_flow(package, init_expr_id);

    if init_flow == ReturnFlow::AlwaysReturns {
        // Everything after this let is dead code.
        strip_returns_from_expr(package, assigner, init_expr_id, return_ty);
        let new_init_ty = package.get_expr(init_expr_id).ty.clone();
        let new_stmt_id = alloc_expr_stmt(package, assigner, init_expr_id, Span::default());
        let block = package.blocks.get_mut(block_id).expect("block not found");
        // Drop the let and all statements after it; the stripped init becomes
        // the block's sole trailing expression and determines its type.
        block.stmts.truncate(idx);
        block.stmts.push(new_stmt_id);
        block.ty = new_init_ty;
        return true;
    }

    // MayReturn: try to decompose the returning init into a guard statement
    // at the outer block level so the continuation only runs on fallthrough.
    if init_flow == ReturnFlow::MayReturn
        && decompose_returning_init(
            package,
            assigner,
            block_id,
            return_ty,
            idx,
            pat_id,
            init_expr_id,
        )
    {
        return true;
    }

    // FallsThrough or failed decomposition: strip returns and retype.
    strip_returns_from_expr(package, assigner, init_expr_id, return_ty);
    let new_init_ty = package.get_expr(init_expr_id).ty.clone();

    // Update the pattern's type to match the stripped init.
    // Only a simple `Bind` pattern yields a local whose Var uses need retyping.
    let local_var_id = match &package.get_pat(pat_id).kind {
        PatKind::Bind(ident) => Some(ident.id),
        _ => None,
    };
    let pat = package.pats.get_mut(pat_id).expect("pat not found");
    pat.ty = new_init_ty.clone();

    // Update all Var references to this local in the block.
    if let Some(var_id) = local_var_id {
        let block_stmts = package.get_block(block_id).stmts.clone();
        for &stmt_id in &block_stmts {
            update_local_var_type(package, stmt_id, var_id, &new_init_ty);
        }
    }

    // Re-analyze this block after stripping so any newly exposed
    // returns or nested wrappers are normalized, then synchronize the
    // block type with its new trailing expression.
    let _ = transform_block_if_else(package, assigner, block_id, return_ty);
    sync_block_type_to_trailing_expr(package, block_id);
    true
}

/// Decompose a `MayReturn` init expression into a guard statement at the
/// outer block level, preserving the return so `transform_block_if_else`
/// can handle it with proper continuation threading.
///
/// Handles `if cond { RETURN_BRANCH } else { VALUE_BRANCH }` patterns
/// (and the inverse) by extracting the return-bearing branch into a
/// preceding guard statement and replacing the init with just the value.
///
/// Returns `true` if decomposition succeeded and the block was restructured.
#[allow(clippy::too_many_arguments)]
fn decompose_returning_init(
    package: &mut Package,
    assigner: &mut Assigner,
    block_id: BlockId,
    return_ty: &Ty,
    idx: usize,
    pat_id: PatId,
    init_expr_id: ExprId,
) -> bool {
    // Clone the kind: the branches below mutate `package`.
    let init_kind = package.get_expr(init_expr_id).kind.clone();

    match init_kind {
        ExprKind::If(cond_id, then_id, Some(else_id)) => {
            let then_flow = expr_return_flow(package, then_id);
            let else_flow = expr_return_flow(package, else_id);

            match (then_flow, else_flow) {
                (ReturnFlow::AlwaysReturns | ReturnFlow::MayReturn, ReturnFlow::FallsThrough) => {
                    // Then branch returns, else is the fallthrough value.
                    // Insert: if cond { then_branch } as guard (return preserved!)
                    // Replace init with: else value
                    extract_guard_and_replace_init(
                        package,
                        assigner,
                        block_id,
                        return_ty,
                        idx,
                        pat_id,
                        init_expr_id,
                        cond_id,
                        then_id,
                        else_id,
                    );
                    true
                }
                (ReturnFlow::FallsThrough, ReturnFlow::AlwaysReturns | ReturnFlow::MayReturn) => {
                    // Else branch returns, then is the fallthrough value.
                    // Negate condition and swap.
                    let neg_cond = alloc_not_expr(package, assigner, cond_id, Span::default());
                    extract_guard_and_replace_init(
                        package,
                        assigner,
                        block_id,
                        return_ty,
                        idx,
                        pat_id,
                        init_expr_id,
                        neg_cond,
                        else_id,
                        then_id,
                    );
                    true
                }
                // Both branches may return (or neither falls through cleanly):
                // no single fallthrough value exists to bind, so give up.
                _ => false,
            }
        }
        ExprKind::Block(inner_block_id) => {
            // Unwrap block and try to decompose the trailing expression.
            let inner_stmts = package.get_block(inner_block_id).stmts.clone();
            let Some(&tail_stmt_id) = inner_stmts.last() else {
                return false;
            };
            let tail_stmt = package.get_stmt(tail_stmt_id);
            let tail_expr_id = match &tail_stmt.kind {
                StmtKind::Expr(expr_id) | StmtKind::Semi(expr_id) => *expr_id,
                _ => return false,
            };
            // If the block has prefix statements before the tail, we can't
            // simply decompose — the prefix needs to stay in scope.
            if inner_stmts.len() > 1 {
                return false;
            }
            decompose_returning_init(
                package,
                assigner,
                block_id,
                return_ty,
                idx,
                pat_id,
                tail_expr_id,
            )
        }
        _ => false,
    }
}

/// Extract a return-bearing branch from an if-expression init into a guard
/// statement, replacing the init with the fallthrough value.
///
/// ```text
/// // Before
/// let x = if cond { return v; } else { u };
/// continuation;
///
/// // After
/// if cond { return v; }   // guard stmt inserted before the let
/// let x = u;              // init replaced with fallthrough value
/// continuation;
/// ```
#[allow(clippy::too_many_arguments)]
fn extract_guard_and_replace_init(
    package: &mut Package,
    assigner: &mut Assigner,
    block_id: BlockId,
    return_ty: &Ty,
    idx: usize,
    pat_id: PatId,
    init_expr_id: ExprId,
    cond_id: ExprId,
    return_branch_id: ExprId,
    value_branch_id: ExprId,
) {
    // Create the guard: if cond { return_branch } (no else — it's a guard)
    let guard_if = {
        // The guard if has no else, so its expression type is Unit.
        let ty: &Ty = &Ty::UNIT;
        alloc_if_expr(
            package,
            assigner,
            cond_id,
            return_branch_id,
            None,
            ty.clone(),
            Span::default(),
        )
    };
    let guard_stmt = alloc_semi_stmt(package, assigner, guard_if, Span::default());

    // Replace the init expression with the fallthrough value.
    // The node is rewritten in place (same ExprId) so existing references to
    // the init remain valid; its exec graph range is reset to empty.
    let value_expr = package.get_expr(value_branch_id).clone();
    let init = package
        .exprs
        .get_mut(init_expr_id)
        .expect("init expr not found");
    init.kind = value_expr.kind;
    init.ty = value_expr.ty.clone();
    init.exec_graph_range = EMPTY_EXEC_RANGE;

    // Retype the pattern to match the new init type.
    let value_ty = value_expr.ty;
    let pat = package.pats.get_mut(pat_id).expect("pat not found");
    pat.ty = value_ty.clone();

    // Update all Var references to this local.
    if let PatKind::Bind(ident) = &package.get_pat(pat_id).kind {
        let var_id = ident.id;
        let block_stmts = package.get_block(block_id).stmts.clone();
        for &stmt_id in &block_stmts {
            update_local_var_type(package, stmt_id, var_id, &value_ty);
        }
    }

    // Insert the guard statement before the let binding.
    let block = package.blocks.get_mut(block_id).expect("block not found");
    block.stmts.insert(idx, guard_stmt);

    // Re-analyze the block — transform_block_if_else will find the guard
    // and move the continuation (including the let + subsequent stmts) into
    // the else branch, which is the correct control flow.
    let _ = transform_block_if_else(package, assigner, block_id, return_ty);
    sync_block_type_to_trailing_expr(package, block_id);
}

/// Recursively rewrite an inner block wrapped in a statement-position `Block` expression.
///
/// After the inner block is rewritten, this helper:
/// 1. Retypes the wrapper `Block` expression to match the inner block's new type.
/// 2. Promotes a trailing `Semi` wrapper to `Expr` so the inner value flows
///    out as the enclosing block's trailing expression.
///
/// # Before
/// ```text
/// { stmts_before; { stmts_inner; if cond { return v; } } }
/// ```
/// # After
/// ```text
/// {
///     stmts_before;
///     { stmts_inner; if cond { v } else { () } }
/// }
/// ```
/// # Requires
/// - `inner_block_id` is the body of the `Block` expression at `stmts[idx]`.
/// - `stmts` is the current statement list of `block_id`.
/// - The normalization pre-pass has run.
///
/// # Ensures
/// - Returns `true` iff inner or outer rewriting made progress.
/// - When the inner transform cannot proceed, returns `false` without
///   recursing to avoid infinite recursion.
/// - Block types remain consistent with their trailing expressions
///   (via `sync_block_type_to_trailing_expr` when needed).
///
/// # Mutations
/// - Rewrites the wrapper `Expr.ty` at `stmts[idx]`.
/// - May rewrite `Block.stmts`, `Block.ty`, and statement kinds for both
///   inner and outer blocks.
/// - Allocates new FIR nodes through `assigner`.
#[allow(clippy::too_many_arguments)]
fn transform_nested_block(
    package: &mut Package,
    assigner: &mut Assigner,
    block_id: BlockId,
    return_ty: &Ty,
    stmts: &[StmtId],
    idx: usize,
    inner_block_id: BlockId,
) -> bool {
    // Before transforming the inner block, check whether its first
    // return-containing statement is unconditional (BareReturn or
    // IfBothReturn). When the return is unconditional, ALL code paths
    // through the inner block end in a return, so statements after
    // this nested-block wrapper in the outer block are dead code.
    let is_unconditional_return =
        block_return_flow(package, inner_block_id) == ReturnFlow::AlwaysReturns;

    let inner_changed = transform_block_if_else(package, assigner, inner_block_id, return_ty);

    // If the inner block couldn't be transformed (e.g. the return is
    // inside a While that must be handled by the flag-based path),
    // stop to avoid infinite recursion.
    if !inner_changed {
        return false;
    }

    // Update the Block expression's type to match the inner block's new type.
    let new_inner_ty = package.get_block(inner_block_id).ty.clone();
    let wrapper_expr_id = match &package.get_stmt(stmts[idx]).kind {
        StmtKind::Expr(e) | StmtKind::Semi(e) => *e,
        // classify_return_stmt only produces NestedBlock from Expr/Semi.
        _ => unreachable!("NestedBlock must be Expr or Semi"),
    };
    let e = package
        .exprs
        .get_mut(wrapper_expr_id)
        .expect("expr not found");
    e.ty = new_inner_ty.clone();

    // When the inner block's return was unconditional and there are
    // statements after this one in the outer block, all subsequent
    // statements are dead code. Truncate them and promote the block
    // expression to the outer block's trailing expression.
    if is_unconditional_return && idx < stmts.len() - 1 {
        apply_bare_return(package, assigner, block_id, idx, wrapper_expr_id);
        return true;
    }

    // If this is the last statement and is Semi, promote to Expr so the
    // value flows through as the block's trailing expression.
    if idx == stmts.len() - 1 {
        let stmt = package.stmts.get_mut(stmts[idx]).expect("stmt not found");
        if matches!(stmt.kind, StmtKind::Semi(_)) {
            stmt.kind = StmtKind::Expr(wrapper_expr_id);
        }
        let block = package.blocks.get_mut(block_id).expect("block not found");
        block.ty = new_inner_ty;
    } else {
        sync_block_type_to_trailing_expr(package, block_id);
    }

    // Re-analyze this block after inner transform.
    let outer_changed = transform_block_if_else(package, assigner, block_id, return_ty);
    inner_changed || outer_changed
}

/// Rewrite a bare-return statement into a trailing expression.
///
/// Runs [`apply_bare_return`] to drop post-return statements and install
/// the return value as the block's trailing expression.
///
/// ```text
/// // Before
/// {
///     stmts_before;
///     return v;
///     stmts_dead;
/// }
///
/// // After
/// {
///     stmts_before;
///     v
/// }
/// ```
fn transform_bare_return(
    package: &mut Package,
    assigner: &mut Assigner,
    block_id: BlockId,
    idx: usize,
    inner_expr_id: ExprId,
) -> bool {
    apply_bare_return(package, assigner, block_id, idx, inner_expr_id);
    // A bare return is always rewritable, so this always reports progress.
    true
}

/// Synchronize a block's type with the type of its trailing expression.
///
/// # Before
/// ```text
/// Block { stmts: [..., Expr(e: T)], ty: U }   // U may have drifted from T
/// ```
/// # After
/// ```text
/// Block { stmts: [..., Expr(e: T)], ty: T }
/// ```
/// # Requires
/// - `block_id` is valid in `package`.
///
/// # Ensures
/// - `Block.ty == trailing_expr.ty` after return if the block ends in a
///   `StmtKind::Expr`.
/// - No-op when the block is empty or ends in a non-expression statement.
///
/// # Mutations
/// - Writes `Block.ty` for `block_id` in place.
+fn sync_block_type_to_trailing_expr(package: &mut Package, block_id: BlockId) { + let Some(&stmt_id) = package.get_block(block_id).stmts.last() else { + return; + }; + + let StmtKind::Expr(expr_id) = package.get_stmt(stmt_id).kind else { + return; + }; + + let trailing_ty = package.get_expr(expr_id).ty.clone(); + let block = package.blocks.get_mut(block_id).expect("block not found"); + block.ty = trailing_ty; +} + +/// Strongly sync a block's type to its tail: the trailing `Expr`'s type +/// when present, otherwise `Unit`. Used by the flag-strategy's Return +/// replacement so nested blocks whose trailing `Return(v)` expression was +/// typed to the callable return type get their type refreshed to `Unit` +/// once the Return has been replaced with a Unit flag-assignment block. +fn sync_block_type_to_stmt_or_unit(package: &mut Package, block_id: BlockId) { + let trailing_ty = match package.get_block(block_id).stmts.last() { + Some(&stmt_id) => match package.get_stmt(stmt_id).kind { + StmtKind::Expr(expr_id) => package.get_expr(expr_id).ty.clone(), + _ => Ty::UNIT, + }, + None => Ty::UNIT, + }; + let block = package.blocks.get_mut(block_id).expect("block not found"); + block.ty = trailing_ty; +} + +/// Replace a block's statements at and after `idx` with a single trailing +/// expression carrying the returned value. +/// +/// ```text +/// // Before (stmt at idx is Expr(Return(inner)) or Semi(Return(inner))) +/// { +/// stmts[..idx]; +/// return inner; +/// stmts[idx+1..]; +/// } +/// +/// // After +/// { +/// stmts[..idx]; +/// inner +/// } +/// ``` +/// +/// Also updates the block's type to the type of `inner`. +fn apply_bare_return( + package: &mut Package, + assigner: &mut Assigner, + block_id: BlockId, + idx: usize, + inner_expr_id: ExprId, +) { + let inner_ty = package.get_expr(inner_expr_id).ty.clone(); + + // Create a new trailing-expression statement for the inner value. 
+ let new_stmt_id = alloc_expr_stmt(package, assigner, inner_expr_id, Span::default()); + + let block = package.blocks.get_mut(block_id).expect("block not found"); + block.stmts.truncate(idx); + block.stmts.push(new_stmt_id); + block.ty = inner_ty; +} + +/// Rewrite an `if cond { return v; }` guard-clause into a return-free +/// `if cond { v } else { /* continuation */ }` trailing expression. +/// +/// Statements after the `if` become the new else branch (recursively +/// rewritten via [`transform_block_if_else`]). When the original `if` +/// already had an else, it is preserved as a leading `Semi` statement +/// inside that new else block. +/// +/// ```text +/// // Before +/// { +/// stmts_before; +/// if cond { return v; } else { side_effect; } +/// rest; +/// } +/// +/// // After +/// { +/// stmts_before; +/// if cond { v } else { side_effect; rest; } +/// } +/// ``` +#[allow(clippy::too_many_arguments)] +#[allow(clippy::too_many_lines)] +fn apply_if_then_return( + package: &mut Package, + assigner: &mut Assigner, + block_id: BlockId, + idx: usize, + cond: ExprId, + then_expr: ExprId, + else_opt: Option, + return_ty: &Ty, +) { + // Strip returns from the then branch. + strip_returns_from_expr(package, assigner, then_expr, return_ty); + + // Collect remaining statements after the if. 
+ let remaining_stmts: Vec = package.get_block(block_id).stmts[idx + 1..].to_vec(); + + let then_flow = expr_return_flow(package, then_expr); + + if then_flow == ReturnFlow::AlwaysReturns { + let new_else_expr_id = create_fallthrough_continuation_expr( + package, + assigner, + else_opt, + remaining_stmts, + return_ty, + ); + let new_if_expr_id = { + let else_expr = Some(new_else_expr_id); + alloc_if_expr( + package, + assigner, + cond, + then_expr, + else_expr, + return_ty.clone(), + Span::default(), + ) + }; + let new_tail = alloc_expr_stmt(package, assigner, new_if_expr_id, Span::default()); + + let block = package.blocks.get_mut(block_id).expect("block not found"); + block.stmts.truncate(idx); + block.stmts.push(new_tail); + block.ty = return_ty.clone(); + return; + } + + if let Some(else_expr_id) = else_opt.filter(|_| remaining_stmts.is_empty()) { + // No remaining statements; keep the existing else as-is but strip returns. + strip_returns_from_expr(package, assigner, else_expr_id, return_ty); + + // Create the new if expression using the existing else. + let new_if_expr_id = { + let else_expr = Some(else_expr_id); + alloc_if_expr( + package, + assigner, + cond, + then_expr, + else_expr, + return_ty.clone(), + Span::default(), + ) + }; + let new_tail = alloc_expr_stmt(package, assigner, new_if_expr_id, Span::default()); + + let block = package.blocks.get_mut(block_id).expect("block not found"); + block.stmts.truncate(idx); + block.stmts.push(new_tail); + block.ty = return_ty.clone(); + return; + } + + // Build a new block from the remaining statements (plus any existing else content). + let new_else_block_id = if let Some(else_expr_id) = else_opt { + // Prepend the existing else as a Semi statement in the new continuation block. 
+ let else_semi = alloc_semi_stmt(package, assigner, else_expr_id, Span::default()); + let mut new_stmts = vec![else_semi]; + new_stmts.extend(remaining_stmts); + alloc_block( + package, + assigner, + new_stmts, + return_ty.clone(), + Span::default(), + ) + } else { + if remaining_stmts.is_empty() { + // Invariant: `should_use_flag_strategy` routes non-Unit leaky + // early-return shapes to the flag strategy, so this empty-else + // synthesis is only reachable for Unit-typed returns. + assert!( + *return_ty == Ty::UNIT, + "apply_if_then_return reached empty-else for non-Unit return type — \ + should have been routed through transform_block_with_flags" + ); + } + alloc_block( + package, + assigner, + remaining_stmts, + return_ty.clone(), + Span::default(), + ) + }; + + // Recursively transform the new else block (it may contain more returns). + transform_block_if_else(package, assigner, new_else_block_id, return_ty); + + // Create new else expression wrapping the block. + let new_else_expr_id = alloc_block_expr( + package, + assigner, + new_else_block_id, + return_ty.clone(), + Span::default(), + ); + + // Create the new if expression. + let new_if_expr_id = { + let else_expr = Some(new_else_expr_id); + alloc_if_expr( + package, + assigner, + cond, + then_expr, + else_expr, + return_ty.clone(), + Span::default(), + ) + }; + let new_tail = alloc_expr_stmt(package, assigner, new_if_expr_id, Span::default()); + + let block = package.blocks.get_mut(block_id).expect("block not found"); + block.stmts.truncate(idx); + block.stmts.push(new_tail); + block.ty = return_ty.clone(); +} + +/// Rewrite an `if cond { return a; } else { return b; }` into a return-free +/// value-producing tail. If release statements follow the if, capture the +/// selected value before the releases and leave the captured value as the +/// block's trailing expression. 
+/// +/// ```text +/// // Before +/// { +/// stmts_before; +/// if cond { return a; } else { return b; } +/// release_call(); +/// stmts_dead; +/// } +/// +/// // After +/// { +/// stmts_before; +/// let __return_unify_result = if cond { a } else { b }; +/// release_call(); +/// __return_unify_result +/// } +/// ``` +#[allow(clippy::too_many_arguments)] +fn apply_if_both_return( + package: &mut Package, + assigner: &mut Assigner, + block_id: BlockId, + idx: usize, + cond: ExprId, + then_expr: ExprId, + else_expr: ExprId, + return_ty: &Ty, +) { + strip_returns_from_expr(package, assigner, then_expr, return_ty); + strip_returns_from_expr(package, assigner, else_expr, return_ty); + + let new_if_expr_id = { + let else_expr = Some(else_expr); + alloc_if_expr( + package, + assigner, + cond, + then_expr, + else_expr, + return_ty.clone(), + Span::default(), + ) + }; + + // Both branches contain returns, so any statements after the if are dead. + let new_tail = alloc_expr_stmt(package, assigner, new_if_expr_id, Span::default()); + + let block = package.blocks.get_mut(block_id).expect("block not found"); + block.stmts.truncate(idx); + block.stmts.push(new_tail); + block.ty = return_ty.clone(); +} + +/// Strip `ExprKind::Return` nodes from an expression tree in place. +/// +/// Lifts returned values to take the place of the `Return` wrapper, and +/// retypes enclosing `Block` and `If` expressions so the lifted value's +/// type propagates outward (`()` → `return_ty`). +/// +/// # Before +/// ```text +/// return v // ExprKind::Return(v) : () +/// { stmts; return v; } // ExprKind::Block : () +/// ``` +/// # After +/// ```text +/// v // v.kind : T +/// { stmts; v } // ExprKind::Block : T +/// ``` +/// # Requires +/// - `expr_id` is valid in `package`. +/// - `return_ty` is the enclosing callable's return type. +/// +/// # Ensures +/// - Every `ExprKind::Return` reachable through `Block`/`If`/compound +/// descent is replaced with the inner value. 
+/// - `Block` and `If` expression types are refreshed to propagate the +/// lifted value's type. +/// +/// # Mutations +/// - Rewrites `Expr` nodes in place via `package.exprs.get_mut`. +/// - Rewrites nested `Block` contents via [`strip_returns_from_block`]. +/// - Allocates new FIR nodes through `assigner` where required. +#[allow(clippy::too_many_lines)] +fn strip_returns_from_expr( + package: &mut Package, + assigner: &mut Assigner, + expr_id: ExprId, + return_ty: &Ty, +) { + let expr = package.get_expr(expr_id).clone(); + match &expr.kind { + ExprKind::Return(inner) => { + let inner_expr = package.get_expr(*inner).clone(); + let e = package.exprs.get_mut(expr_id).expect("expr not found"); + *e = Expr { + id: expr_id, + span: expr.span, + ty: inner_expr.ty.clone(), + kind: inner_expr.kind.clone(), + exec_graph_range: EMPTY_EXEC_RANGE, + }; + // Recursively strip in case the inner also has returns. + strip_returns_from_expr(package, assigner, expr_id, return_ty); + } + ExprKind::Block(block_id) => { + let bid = *block_id; + strip_returns_from_block(package, assigner, bid, return_ty); + // Update the Block expression's type to match the block's new type. + let new_block_ty = package.get_block(bid).ty.clone(); + let e = package.exprs.get_mut(expr_id).expect("expr not found"); + e.ty = new_block_ty; + } + ExprKind::If(_, then_expr, else_opt) => { + let then_id = *then_expr; + let else_id = *else_opt; + strip_returns_from_expr(package, assigner, then_id, return_ty); + if let Some(e) = else_id { + strip_returns_from_expr(package, assigner, e, return_ty); + } + // Update the If expression's type to match the return type. + let e = package.exprs.get_mut(expr_id).expect("expr not found"); + e.ty = return_ty.clone(); + } + // Compound-expression descent. Sub-expressions are visited so any + // `Return` nested through these kinds after normalization is still + // stripped defensively. 
Types of these kinds are not refreshed because
+        // valid normalized FIR should not leave return-bearing values here.
+        ExprKind::Array(exprs) | ExprKind::ArrayLit(exprs) | ExprKind::Tuple(exprs) => {
+            let ids: Vec<ExprId> = exprs.clone();
+            for e in ids {
+                strip_returns_from_expr(package, assigner, e, return_ty);
+            }
+        }
+        ExprKind::ArrayRepeat(a, b)
+        | ExprKind::Assign(a, b)
+        | ExprKind::AssignOp(_, a, b)
+        | ExprKind::BinOp(_, a, b)
+        | ExprKind::Call(a, b)
+        | ExprKind::Index(a, b)
+        | ExprKind::AssignField(a, _, b)
+        | ExprKind::UpdateField(a, _, b) => {
+            let (a_id, b_id) = (*a, *b);
+            strip_returns_from_expr(package, assigner, a_id, return_ty);
+            strip_returns_from_expr(package, assigner, b_id, return_ty);
+        }
+        ExprKind::AssignIndex(a, b, c) | ExprKind::UpdateIndex(a, b, c) => {
+            let (a_id, b_id, c_id) = (*a, *b, *c);
+            strip_returns_from_expr(package, assigner, a_id, return_ty);
+            strip_returns_from_expr(package, assigner, b_id, return_ty);
+            strip_returns_from_expr(package, assigner, c_id, return_ty);
+        }
+        ExprKind::Fail(e) | ExprKind::Field(e, _) | ExprKind::UnOp(_, e) => {
+            let sub = *e;
+            strip_returns_from_expr(package, assigner, sub, return_ty);
+        }
+        ExprKind::Range(start, step, end) => {
+            let ids: Vec<ExprId> = [start, step, end].into_iter().flatten().copied().collect();
+            for e in ids {
+                strip_returns_from_expr(package, assigner, e, return_ty);
+            }
+        }
+        ExprKind::Struct(_, copy, fields) => {
+            let copy_id = *copy;
+            let field_ids: Vec<ExprId> = fields.iter().map(|fa| fa.value).collect();
+            if let Some(c) = copy_id {
+                strip_returns_from_expr(package, assigner, c, return_ty);
+            }
+            for e in field_ids {
+                strip_returns_from_expr(package, assigner, e, return_ty);
+            }
+        }
+        ExprKind::String(components) => {
+            let ids: Vec<ExprId> = components
+                .iter()
+                .filter_map(|c| match c {
+                    qsc_fir::fir::StringComponent::Expr(e) => Some(*e),
+                    qsc_fir::fir::StringComponent::Lit(_) => None,
+                })
+                .collect();
+            for e in ids {
+                strip_returns_from_expr(package, assigner, e, return_ty);
+            }
+        }
+        ExprKind::While(cond, body) => {
+            let (cond_id, body_id) = (*cond, *body);
+            strip_returns_from_expr(package, assigner, cond_id, return_ty);
+            // Walk every statement-level expression inside the while body.
+            let stmts = package.get_block(body_id).stmts.clone();
+            for stmt_id in stmts {
+                let expr_ids: Vec<ExprId> = {
+                    let stmt = package.get_stmt(stmt_id);
+                    match &stmt.kind {
+                        StmtKind::Expr(e) | StmtKind::Semi(e) | StmtKind::Local(_, _, e) => {
+                            vec![*e]
+                        }
+                        StmtKind::Item(_) => vec![],
+                    }
+                };
+                for e in expr_ids {
+                    strip_returns_from_expr(package, assigner, e, return_ty);
+                }
+            }
+        }
+        ExprKind::Closure(_, _) | ExprKind::Hole | ExprKind::Lit(_) | ExprKind::Var(_, _) => {}
+    }
+}
+
+/// Strip returns from a block by transforming it with if-else lifting,
+/// using the function's return type rather than the block's own type.
+fn strip_returns_from_block(
+    package: &mut Package,
+    assigner: &mut Assigner,
+    block_id: BlockId,
+    return_ty: &Ty,
+) {
+    transform_block_if_else(package, assigner, block_id, return_ty);
+}
+
+/// Retype every `Var(Local(var_id))` expression reachable from a statement
+/// to `new_ty`.
+///
+/// Used after [`transform_local_init`] strips returns from a `let`
+/// initializer, to keep reads of the bound local type-consistent with the
+/// newly-lifted init type.
+///
+/// ```text
+/// // Before (init type lifted from () to T after strip_returns_from_expr)
+/// let x : () = { ... }; // x reads typed ()
+///
+/// // After
+/// let x : T = { ...
}; // every Var(x) retyped to T +/// ``` +fn update_local_var_type(package: &mut Package, stmt_id: StmtId, var_id: LocalVarId, new_ty: &Ty) { + let expr_ids: Vec = { + let stmt = package.get_stmt(stmt_id); + match &stmt.kind { + StmtKind::Expr(e) | StmtKind::Semi(e) | StmtKind::Local(_, _, e) => vec![*e], + StmtKind::Item(_) => vec![], + } + }; + for expr_id in expr_ids { + update_local_var_type_in_expr(package, expr_id, var_id, new_ty); + } +} + +/// Recursively retype every `Var(Local(var_id))` read inside an expression tree. +/// +/// # Before +/// ```text +/// Var(Local(var_id)) : OldTy // anywhere in the subtree +/// ``` +/// # After +/// ```text +/// Var(Local(var_id)) : NewTy +/// ``` +/// # Requires +/// - `expr_id` is valid in `package`. +/// - `var_id` is the binding whose referencing `Var`s must be retyped. +/// +/// # Ensures +/// - Every `Var(Local(var_id))` reachable through `Block`/`If`/compound +/// descent has its `Expr.ty` set to `new_ty`. +/// - Does not touch `Var`s resolving to other locals or non-local `Res`. +/// +/// # Mutations +/// - Writes `Expr.ty` in place for each matching `Var` node. +fn update_local_var_type_in_expr( + package: &mut Package, + expr_id: ExprId, + var_id: LocalVarId, + new_ty: &Ty, +) { + let kind = package.get_expr(expr_id).kind.clone(); + match &kind { + ExprKind::Var(Res::Local(id), _) if *id == var_id => { + let e = package.exprs.get_mut(expr_id).expect("expr not found"); + e.ty = new_ty.clone(); + } + ExprKind::Block(block_id) => { + let stmts = package.get_block(*block_id).stmts.clone(); + for stmt_id in stmts { + update_local_var_type(package, stmt_id, var_id, new_ty); + } + } + ExprKind::If(_, then_id, else_opt) => { + update_local_var_type_in_expr(package, *then_id, var_id, new_ty); + if let Some(e) = *else_opt { + update_local_var_type_in_expr(package, e, var_id, new_ty); + } + } + // Exhaustive descent through every compound `ExprKind`. 
Closes G3:
+    // a retype request must reach every `Var(Local(var_id))` read no
+    // matter how deeply nested it is, not only those inside Block/If.
+        ExprKind::Array(exprs) | ExprKind::ArrayLit(exprs) | ExprKind::Tuple(exprs) => {
+            let ids: Vec<ExprId> = exprs.clone();
+            for e in ids {
+                update_local_var_type_in_expr(package, e, var_id, new_ty);
+            }
+        }
+        ExprKind::ArrayRepeat(a, b)
+        | ExprKind::Assign(a, b)
+        | ExprKind::AssignOp(_, a, b)
+        | ExprKind::BinOp(_, a, b)
+        | ExprKind::Call(a, b)
+        | ExprKind::Index(a, b)
+        | ExprKind::AssignField(a, _, b)
+        | ExprKind::UpdateField(a, _, b) => {
+            let (a_id, b_id) = (*a, *b);
+            update_local_var_type_in_expr(package, a_id, var_id, new_ty);
+            update_local_var_type_in_expr(package, b_id, var_id, new_ty);
+        }
+        ExprKind::AssignIndex(a, b, c) | ExprKind::UpdateIndex(a, b, c) => {
+            let (a_id, b_id, c_id) = (*a, *b, *c);
+            update_local_var_type_in_expr(package, a_id, var_id, new_ty);
+            update_local_var_type_in_expr(package, b_id, var_id, new_ty);
+            update_local_var_type_in_expr(package, c_id, var_id, new_ty);
+        }
+        ExprKind::Fail(e) | ExprKind::Field(e, _) | ExprKind::Return(e) | ExprKind::UnOp(_, e) => {
+            let sub = *e;
+            update_local_var_type_in_expr(package, sub, var_id, new_ty);
+        }
+        ExprKind::Range(start, step, end) => {
+            let ids: Vec<ExprId> = [start, step, end].into_iter().flatten().copied().collect();
+            for e in ids {
+                update_local_var_type_in_expr(package, e, var_id, new_ty);
+            }
+        }
+        ExprKind::Struct(_, copy, fields) => {
+            let copy_id = *copy;
+            let field_ids: Vec<ExprId> = fields.iter().map(|fa| fa.value).collect();
+            if let Some(c) = copy_id {
+                update_local_var_type_in_expr(package, c, var_id, new_ty);
+            }
+            for e in field_ids {
+                update_local_var_type_in_expr(package, e, var_id, new_ty);
+            }
+        }
+        ExprKind::String(components) => {
+            let ids: Vec<ExprId> = components
+                .iter()
+                .filter_map(|c| match c {
+                    qsc_fir::fir::StringComponent::Expr(e) => Some(*e),
+                    qsc_fir::fir::StringComponent::Lit(_) => None,
+                })
+                .collect();
+            for e in ids {
+                update_local_var_type_in_expr(package, e, var_id, new_ty);
+            }
+        }
+        ExprKind::While(cond, body) => {
+            let (cond_id, body_id) = (*cond, *body);
+            update_local_var_type_in_expr(package, cond_id, var_id, new_ty);
+            let stmts = package.get_block(body_id).stmts.clone();
+            for stmt_id in stmts {
+                update_local_var_type(package, stmt_id, var_id, new_ty);
+            }
+        }
+        ExprKind::Var(_, _) | ExprKind::Closure(_, _) | ExprKind::Hole | ExprKind::Lit(_) => {}
+    }
+}
+
+/// Rewrite a block containing returns inside while loops using the flag-based strategy.
+///
+/// Introduces two mutable locals at the top of the block:
+/// * `__has_returned : Bool = false` — set when a return fires.
+/// * `__ret_val : T = default(T)` — holds the returned value (never read
+///   unless `__has_returned` is `true`).
+///
+/// Each while loop containing a return has its condition conjoined with
+/// `not __has_returned` (see [`transform_while_stmt`]), and returns inside
+/// its body are rewritten by [`replace_returns_with_flags`]. Statements
+/// after the first return-bearing while are wrapped by
+/// [`guard_stmt_with_flag`], including release calls. A trailing
+/// `if __has_returned { __ret_val }
+/// else { original_trailing }` is appended by [`create_flag_trailing_expr`].
+/// A final call to [`transform_block_if_else`] mops up any non-while
+/// returns that remain.
+///
+/// # Before
+/// ```text
+/// {
+///     mutable r = 0;
+///     while cond {
+///         if done { return r; }
+///         r += 1;
+///     }
+///     r
+/// }
+/// ```
+/// # After
+/// ```text
+/// {
+///     mutable __has_returned = false;
+///     mutable __ret_val = 0;
+///     mutable r = 0;
+///     while not __has_returned and cond {
+///         if done { __ret_val = r; __has_returned = true; }
+///         else { r += 1; }
+///     }
+///     if __has_returned { __ret_val } else { r }
+/// }
+/// ```
+/// # Requires
+/// - `block_id` is valid in `package`.
+/// - `return_ty` has a synthesizable classical default (see +/// [`create_default_value_kind`]); otherwise this triggers the +/// unsupported-default internal contract panic. +/// +/// # Ensures +/// - While loops exit promptly once `__has_returned` is set. +/// - Post-return continuation statements execute only when no return has fired. +/// - The block's trailing expression produces the return value. +/// +/// # Mutations +/// - Prepends `__has_returned` / `__ret_val` `Local` statements to `block.stmts`. +/// - Rewrites statements carrying returns (while loops, guarded reads, trailing expr). +/// - Allocates new FIR nodes through `assigner`. +#[allow(clippy::too_many_lines)] +fn transform_block_with_flags( + package: &mut Package, + assigner: &mut Assigner, + package_id: PackageId, + block_id: BlockId, + return_ty: &Ty, + udt_pure_tys: &UdtPureTyCache, +) { + // Create __has_returned: Bool = false + let (has_returned_var_id, has_returned_decl_stmt) = + create_mutable_bool_var(package, assigner, "__has_returned", false); + + // Create __ret_val: T = default(T). + // + // For callable-valued return types, `create_default_value` synthesizes + // a nop callable item of the matching signature and returns a + // `Var(Res::Item(..))` reference to it; any later `Call(Var(__ret_val), .)` + // then resolves to that nop (its body returns the output type's default). + // The nop is never actually invoked because `__has_returned` guards + // every read of `__ret_val`, but it keeps the flag-fallback well-typed. 
+ let default_val = require_classical_default( + package, + assigner, + package_id, + return_ty, + udt_pure_tys, + UnsupportedDefaultSite::ReturnSlot, + ); + let (ret_val_var_id, ret_val_decl_stmt) = { + let mutability = Mutability::Mutable; + alloc_local_var( + package, + assigner, + "__ret_val", + return_ty, + default_val, + mutability, + ) + }; + + let original_stmts = package.get_block(block_id).stmts.clone(); + let mut new_stmts: Vec = Vec::new(); + + // Insert flag declarations. + new_stmts.push(has_returned_decl_stmt); + new_stmts.push(ret_val_decl_stmt); + + let mut seen_return_bearing_stmt = false; + + for (index, &stmt_id) in original_stmts.iter().enumerate() { + let has_return_in_while = match &package.get_stmt(stmt_id).kind { + StmtKind::Expr(e) | StmtKind::Semi(e) => contains_return_in_while_expr(package, *e), + _ => false, + }; + let has_return = contains_return_in_stmt(package, stmt_id); + let is_final_trailing_expr = index == original_stmts.len() - 1 + && matches!(package.get_stmt(stmt_id).kind, StmtKind::Expr(_)); + + if has_return_in_while { + // Transform the while loop (conjoins `not __has_returned` onto + // the condition and rewrites Returns in its body via the flag + // slot). + transform_while_stmt( + package, + assigner, + package_id, + stmt_id, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + new_stmts.push(stmt_id); + seen_return_bearing_stmt = true; + } else if has_return && !seen_return_bearing_stmt { + // First return-bearing non-while statement. The flag is known + // to be `false` on entry so no guard is needed here; rewriting + // the returns in place to flag assignments is sufficient. + replace_returns_with_flags( + package, + assigner, + package_id, + stmt_id, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + new_stmts.push(stmt_id); + seen_return_bearing_stmt = true; + } else if has_return { + // Subsequent return-bearing statement after another + // return-bearing statement has already fired. 
Rewrite returns, + // then guard the whole statement so it is skipped when the + // earlier return already set `__has_returned`. + replace_returns_with_flags( + package, + assigner, + package_id, + stmt_id, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + let guarded = guard_stmt_with_flag( + package, + assigner, + package_id, + stmt_id, + has_returned_var_id, + udt_pure_tys, + ); + new_stmts.push(guarded); + } else if seen_return_bearing_stmt && is_final_trailing_expr { + // Preserve the original trailing value so the final flag check + // can return it from the else branch instead of discarding it + // as a guarded semicolon statement. + new_stmts.push(stmt_id); + } else if seen_return_bearing_stmt { + // Guard continuation statements that follow a return-bearing + // statement so they are skipped once the flag is set. Release + // calls are ordinary side effects here; no-hoist raw wrappers + // keep path-local releases on the returning paths. + let guarded = guard_stmt_with_flag( + package, + assigner, + package_id, + stmt_id, + has_returned_var_id, + udt_pure_tys, + ); + new_stmts.push(guarded); + } else { + new_stmts.push(stmt_id); + } + } + + // Create trailing expression: if __has_returned { __ret_val } else { } + let trailing = create_flag_trailing_expr( + package, + assigner, + &mut new_stmts, + has_returned_var_id, + ret_val_var_id, + return_ty, + ); + + if let Some(trailing_stmt) = trailing { + new_stmts.push(trailing_stmt); + } + + let block = package.blocks.get_mut(block_id).expect("block not found"); + block.stmts = new_stmts; + block.ty = return_ty.clone(); + + // Apply if-else lifting to handle any remaining non-while returns. + transform_block_if_else(package, assigner, block_id, return_ty); +} + +/// Post-transform simplification pass that folds trivial flag patterns. 
+/// +/// After the flag-based transform, the output can contain redundant +/// if-expressions whose branches are structurally identical: +/// +/// ```text +/// if __has_returned { x } else { x } → x +/// ``` +/// +/// This pass walks the block's statements and trailing expression, folding +/// such identity patterns. Only clearly safe, semantics-preserving folds +/// are applied. This is the structured-IR analog of LLVM's `SimplifyCFG` +/// running after `mergereturn`. +fn simplify_flag_patterns(package: &mut Package, block_id: BlockId) { + let stmts = package.get_block(block_id).stmts.clone(); + for &stmt_id in &stmts { + simplify_flag_patterns_in_stmt(package, stmt_id); + } +} + +/// Simplify flag patterns within a single statement. +fn simplify_flag_patterns_in_stmt(package: &mut Package, stmt_id: StmtId) { + let expr_id = match &package.get_stmt(stmt_id).kind { + StmtKind::Expr(e) | StmtKind::Semi(e) | StmtKind::Local(_, _, e) => *e, + StmtKind::Item(_) => return, + }; + if let Some(replacement) = try_fold_identical_branches(package, expr_id) { + let stmt = package.stmts.get_mut(stmt_id).expect("stmt not found"); + match &mut stmt.kind { + StmtKind::Expr(e) | StmtKind::Semi(e) | StmtKind::Local(_, _, e) => { + *e = replacement; + } + StmtKind::Item(_) => {} + } + } +} + +/// If `expr_id` is an `If(cond, then_expr, Some(else_expr))` where the +/// then and else branches are structurally identical, return the branch +/// expression id to replace the if with. Returns `None` otherwise. +fn try_fold_identical_branches(package: &Package, expr_id: ExprId) -> Option { + let expr = package.get_expr(expr_id); + let ExprKind::If(_, then_id, Some(else_id)) = &expr.kind else { + return None; + }; + if exprs_structurally_equal(package, *then_id, *else_id) { + Some(*then_id) + } else { + None + } +} + +/// Recursively compare two expression trees for structural equality. 
+///
+/// Two expressions are structurally equal when their `ExprKind` variants
+/// match and all recursive children are structurally equal. Span and
+/// exec-graph metadata are ignored; only the semantic shape matters.
+///
+/// This is intentionally conservative: any unknown or complex pattern
+/// returns `false` to avoid incorrect folding.
+fn exprs_structurally_equal(package: &Package, a: ExprId, b: ExprId) -> bool {
+    if a == b {
+        return true;
+    }
+    let ea = package.get_expr(a);
+    let eb = package.get_expr(b);
+    if ea.ty != eb.ty {
+        return false;
+    }
+    match (&ea.kind, &eb.kind) {
+        (ExprKind::Var(res_a, args_a), ExprKind::Var(res_b, args_b)) => {
+            res_a == res_b && args_a == args_b
+        }
+        (ExprKind::Lit(lit_a), ExprKind::Lit(lit_b)) => lit_a == lit_b,
+        (ExprKind::Tuple(elems_a), ExprKind::Tuple(elems_b)) => {
+            elems_a.len() == elems_b.len()
+                && elems_a
+                    .iter()
+                    .zip(elems_b.iter())
+                    .all(|(&a, &b)| exprs_structurally_equal(package, a, b))
+        }
+        (ExprKind::Block(bid_a), ExprKind::Block(bid_b)) => {
+            blocks_structurally_equal(package, *bid_a, *bid_b)
+        }
+        (ExprKind::UnOp(op_a, operand_a), ExprKind::UnOp(op_b, operand_b)) => {
+            op_a == op_b && exprs_structurally_equal(package, *operand_a, *operand_b)
+        }
+        (ExprKind::BinOp(op_a, l_a, r_a), ExprKind::BinOp(op_b, l_b, r_b)) => {
+            op_a == op_b
+                && exprs_structurally_equal(package, *l_a, *l_b)
+                && exprs_structurally_equal(package, *r_a, *r_b)
+        }
+        (ExprKind::If(c_a, t_a, e_a), ExprKind::If(c_b, t_b, e_b)) => {
+            exprs_structurally_equal(package, *c_a, *c_b)
+                && exprs_structurally_equal(package, *t_a, *t_b)
+                && match (e_a, e_b) {
+                    (Some(ea), Some(eb)) => exprs_structurally_equal(package, *ea, *eb),
+                    (None, None) => true,
+                    _ => false,
+                }
+        }
+        (ExprKind::Array(a_elems), ExprKind::Array(b_elems))
+        | (ExprKind::ArrayLit(a_elems), ExprKind::ArrayLit(b_elems)) => {
+            a_elems.len() == b_elems.len()
+                && a_elems
+                    .iter()
+                    .zip(b_elems.iter())
+                    .all(|(&a, &b)| exprs_structurally_equal(package, a, b))
+        }
+        // Conservative: anything else is considered non-equal.
+        _ => false,
+    }
+}
+
+/// Recursively compare two blocks for structural equality.
+fn blocks_structurally_equal(package: &Package, a: BlockId, b: BlockId) -> bool {
+    if a == b {
+        return true;
+    }
+    let ba = package.get_block(a);
+    let bb = package.get_block(b);
+    if ba.ty != bb.ty || ba.stmts.len() != bb.stmts.len() {
+        return false;
+    }
+    ba.stmts
+        .iter()
+        .zip(bb.stmts.iter())
+        .all(|(&sa, &sb)| stmts_structurally_equal(package, sa, sb))
+}
+
+/// Recursively compare two statements for structural equality.
+fn stmts_structurally_equal(package: &Package, a: StmtId, b: StmtId) -> bool {
+    if a == b {
+        return true;
+    }
+    let sa = package.get_stmt(a);
+    let sb = package.get_stmt(b);
+    match (&sa.kind, &sb.kind) {
+        (StmtKind::Expr(ea), StmtKind::Expr(eb)) | (StmtKind::Semi(ea), StmtKind::Semi(eb)) => {
+            exprs_structurally_equal(package, *ea, *eb)
+        }
+        (StmtKind::Local(m_a, p_a, e_a), StmtKind::Local(m_b, p_b, e_b)) => {
+            m_a == m_b && p_a == p_b && exprs_structurally_equal(package, *e_a, *e_b)
+        }
+        _ => false,
+    }
+}
+
+/// Rewrite a while-loop statement under the flag-based transform.
+///
+/// Delegates to [`transform_while_in_expr`] on the statement's inner
+/// expression; descends through `Block` and `If` wrappers so for-loop
+/// desugarings (which wrap the while in a block) are handled.
+/// +/// ```text +/// // Before +/// while cond { body } +/// +/// // After +/// while not __has_returned and cond { body' } +/// // where body' has all `return v` replaced by +/// // { __ret_val = v; __has_returned = true; } +/// ``` +#[allow(clippy::too_many_arguments)] +fn transform_while_stmt( + package: &mut Package, + assigner: &mut Assigner, + package_id: PackageId, + stmt_id: StmtId, + has_returned_var_id: LocalVarId, + ret_val_var_id: LocalVarId, + udt_pure_tys: &UdtPureTyCache, +) { + let expr_id = match &package.get_stmt(stmt_id).kind { + StmtKind::Expr(e) | StmtKind::Semi(e) => *e, + _ => return, + }; + + transform_while_in_expr( + package, + assigner, + package_id, + expr_id, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); +} + +/// Walk an expression tree, locate every `ExprKind::While` that transitively +/// contains a return, and rewrite it for the flag-based transform. +/// +/// For each such `While`: +/// * Conjoins `not __has_returned` onto the loop condition. +/// * Calls [`replace_returns_in_block`] to rewrite `return v` inside the +/// body as flag-assignment blocks. +/// +/// Descends into `Block` and `If` wrappers so nested structures (including +/// for-loop desugarings) are handled. +/// +/// ```text +/// // Before +/// while cond { ...; if guard { return v; }; ... } +/// +/// // After +/// while not __has_returned and cond { +/// ...; +/// if guard { +/// __ret_val = v; +/// __has_returned = true; +/// }; +/// ... 
+/// } +/// ``` +#[allow(clippy::too_many_arguments)] +fn transform_while_in_expr( + package: &mut Package, + assigner: &mut Assigner, + package_id: PackageId, + expr_id: ExprId, + has_returned_var_id: LocalVarId, + ret_val_var_id: LocalVarId, + udt_pure_tys: &UdtPureTyCache, +) { + let expr = package.get_expr(expr_id).clone(); + match &expr.kind { + ExprKind::While(cond_id, body_block_id) => { + let cond_id = *cond_id; + let body_block_id = *body_block_id; + + if contains_return_in_expr(package, cond_id) { + replace_returns_in_condition_expr( + package, + assigner, + package_id, + cond_id, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + } + + // Conjoin !__has_returned with the while condition. + // LHS must be the flag guard so that AndL short-circuits and + // skips the original condition once a return has fired. + let not_flag = create_not_var_expr(package, assigner, has_returned_var_id); + let new_cond = { + let op = BinOp::AndL; + let ty: &Ty = &Ty::Prim(Prim::Bool); + alloc_bin_op_expr( + package, + assigner, + op, + not_flag, + cond_id, + ty.clone(), + Span::default(), + ) + }; + + // Replace returns inside the body. + if contains_return_in_block(package, body_block_id) { + replace_returns_in_block( + package, + assigner, + package_id, + body_block_id, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + } + + // Update the while expression. 
+ let e = package.exprs.get_mut(expr_id).expect("expr not found"); + *e = Expr { + id: expr_id, + span: expr.span, + ty: expr.ty.clone(), + kind: ExprKind::While(new_cond, body_block_id), + exec_graph_range: EMPTY_EXEC_RANGE, + }; + } + ExprKind::Block(block_id) => { + let stmts = package.get_block(*block_id).stmts.clone(); + for &stmt_id in &stmts { + let inner_expr_id = match &package.get_stmt(stmt_id).kind { + StmtKind::Expr(e) | StmtKind::Semi(e) => *e, + _ => continue, + }; + if contains_return_in_while_expr(package, inner_expr_id) { + transform_while_in_expr( + package, + assigner, + package_id, + inner_expr_id, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + } + } + } + ExprKind::If(_, then_id, else_opt) => { + if contains_return_in_while_expr(package, *then_id) { + transform_while_in_expr( + package, + assigner, + package_id, + *then_id, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + } + if let Some(e) = *else_opt + && contains_return_in_while_expr(package, e) + { + transform_while_in_expr( + package, + assigner, + package_id, + e, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + } + } + _ => {} + } +} + +/// Builds the expression that should execute when an `if` condition falls +/// through instead of returning. +/// +/// Before, the original `else` arm and the statements that continue after the +/// `if` live as separate IR fragments. After, they are combined into one +/// expression or block whose internal `Return` nodes have already been lowered +/// to flag writes, letting later normalization treat the entire fallthrough path +/// as a single else branch. 
+fn create_fallthrough_continuation_expr(
+    package: &mut Package,
+    assigner: &mut Assigner,
+    else_opt: Option<ExprId>,
+    continuation_stmts: Vec<StmtId>,
+    return_ty: &Ty,
+) -> ExprId {
+    if let Some(else_expr_id) = else_opt {
+        strip_returns_from_expr(package, assigner, else_expr_id, return_ty);
+        if continuation_stmts.is_empty() {
+            return else_expr_id;
+        }
+
+        let else_semi = alloc_semi_stmt(package, assigner, else_expr_id, Span::default());
+        let mut new_stmts = Vec::with_capacity(continuation_stmts.len() + 1);
+        new_stmts.push(else_semi);
+        new_stmts.extend(continuation_stmts);
+        let block_id = alloc_block(
+            package,
+            assigner,
+            new_stmts,
+            return_ty.clone(),
+            Span::default(),
+        );
+        transform_block_if_else(package, assigner, block_id, return_ty);
+        return alloc_block_expr(
+            package,
+            assigner,
+            block_id,
+            return_ty.clone(),
+            Span::default(),
+        );
+    }
+
+    if continuation_stmts.is_empty() {
+        assert!(
+            *return_ty == Ty::UNIT,
+            "fallthrough continuation is empty for non-Unit return type"
+        );
+    }
+
+    let block_id = alloc_block(
+        package,
+        assigner,
+        continuation_stmts,
+        return_ty.clone(),
+        Span::default(),
+    );
+    transform_block_if_else(package, assigner, block_id, return_ty);
+    alloc_block_expr(
+        package,
+        assigner,
+        block_id,
+        return_ty.clone(),
+        Span::default(),
+    )
+}
+
+/// Walk every statement in a block and rewrite `Return(val)` subexpressions
+/// into `{ __ret_val = val; __has_returned = true; }` via
+/// [`replace_returns_with_flags`].
+///
+/// After replacement, statements following the first return-bearing
+/// statement in the same block are wrapped in `if not __has_returned { … }`
+/// guards so they are skipped once the flag fires within the current
+/// iteration or scope.
+/// +/// ```text +/// // Before +/// { if g { return v; }; stmt2 } +/// +/// // After +/// { if g { { __ret_val = v; __has_returned = true; } }; +/// if not __has_returned { stmt2 }; } +/// ``` +#[allow(clippy::too_many_arguments)] +fn replace_returns_in_block( + package: &mut Package, + assigner: &mut Assigner, + package_id: PackageId, + block_id: BlockId, + has_returned_var_id: LocalVarId, + ret_val_var_id: LocalVarId, + udt_pure_tys: &UdtPureTyCache, +) { + let stmts = package.get_block(block_id).stmts.clone(); + + // Identify the first statement carrying a return *before* any + // replacement so the index is stable. + let first_return_idx = stmts + .iter() + .position(|&sid| contains_return_in_stmt(package, sid)); + + // Replace returns in every statement. + for &stmt_id in &stmts { + replace_returns_with_flags( + package, + assigner, + package_id, + stmt_id, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + } + + // Guard subsequent statements so they are skipped once the flag is set. + if let Some(first_idx) = first_return_idx + && first_idx + 1 < stmts.len() + { + let last_idx = stmts.len() - 1; + let is_last_trailing_expr = + matches!(package.get_stmt(stmts[last_idx]).kind, StmtKind::Expr(_)); + + let mut new_stmts: Vec = stmts[..=first_idx].to_vec(); + for (i, &stmt_id) in stmts[first_idx + 1..].iter().enumerate() { + let actual_idx = first_idx + 1 + i; + // Preserve the trailing expression without guarding — its + // value is only consumed when `__has_returned` is false + // (the flag-trailing expression handles the true case). 
+ if actual_idx == last_idx && is_last_trailing_expr { + new_stmts.push(stmt_id); + } else { + let guarded = guard_stmt_with_flag( + package, + assigner, + package_id, + stmt_id, + has_returned_var_id, + udt_pure_tys, + ); + new_stmts.push(guarded); + } + } + let block = package.blocks.get_mut(block_id).expect("block not found"); + block.stmts = new_stmts; + } +} + +/// Rewrite `Return(val)` subexpressions in a single statement's expression +/// tree to the flag-assignment pair. +/// +/// ```text +/// // Before +/// Expr(if cond { return v; }) +/// +/// // After +/// Expr(if cond { { __ret_val = v; __has_returned = true; } }) +/// ``` +#[allow(clippy::too_many_arguments)] +fn replace_returns_with_flags( + package: &mut Package, + assigner: &mut Assigner, + package_id: PackageId, + stmt_id: StmtId, + has_returned_var_id: LocalVarId, + ret_val_var_id: LocalVarId, + udt_pure_tys: &UdtPureTyCache, +) { + let expr_id = match &package.get_stmt(stmt_id).kind { + StmtKind::Expr(e) | StmtKind::Semi(e) | StmtKind::Local(_, _, e) => *e, + StmtKind::Item(_) => return, + }; + replace_returns_in_expr( + package, + assigner, + package_id, + expr_id, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + + // Sync Pat type for Local bindings whose initializer type may have + // changed after return replacement (e.g. a Block wrapping an If whose + // else branch was replaced with a Unit-typed flag-assignment block). + if let StmtKind::Local(_, pat_id, init_id) = &package.get_stmt(stmt_id).kind { + let pat_id = *pat_id; + let init_id = *init_id; + let init_ty = package.get_expr(init_id).ty.clone(); + let pat = package.pats.get_mut(pat_id).expect("pat not found"); + pat.ty = init_ty; + } +} + +/// Rewrite `Return(val)` nodes inside an expression tree to Unit-typed flag-assignment blocks. 
+/// +/// Each `Return(val)` expression is replaced in place with: +/// +/// ```text +/// { __ret_val = val; __has_returned = true; } : () +/// ``` +/// +/// # Before +/// ```text +/// if cond { return v; } +/// ``` +/// # After +/// ```text +/// if cond { { __ret_val = v; __has_returned = true; } } +/// ``` +/// # Requires +/// - `expr_id` is valid in `package`. +/// - `has_returned_var_id` and `ret_val_var_id` reference the flag pair +/// introduced by [`transform_block_with_flags`]. +/// +/// # Ensures +/// - Every `ExprKind::Return` reachable through `Block`/`If`/compound +/// descent is replaced with the flag-assignment block. +/// - The outer expression's type becomes Unit at each replacement site; +/// callers guarantee the enclosing loop exits on the next condition check. +/// +/// # Mutations +/// - Rewrites `Expr` nodes in place at each Return replacement site. +/// - Recurses into nested blocks via [`replace_returns_in_block`]. +/// - Allocates new FIR nodes through `assigner`. 
+#[allow(clippy::too_many_arguments)] +#[allow(clippy::too_many_lines)] +fn replace_returns_in_expr( + package: &mut Package, + assigner: &mut Assigner, + package_id: PackageId, + expr_id: ExprId, + has_returned_var_id: LocalVarId, + ret_val_var_id: LocalVarId, + udt_pure_tys: &UdtPureTyCache, +) { + let expr = package.get_expr(expr_id).clone(); + match &expr.kind { + ExprKind::Return(inner) => { + let inner_id = *inner; + let inner_ty = package.get_expr(inner_id).ty.clone(); + // Build: { __ret_val = val; __has_returned = true; } + let assign_val = + create_assign_expr(package, assigner, ret_val_var_id, inner_id, &inner_ty); + let assign_val_semi = alloc_semi_stmt(package, assigner, assign_val, Span::default()); + + let true_lit = alloc_bool_lit(package, assigner, true, Span::default()); + let assign_flag = create_assign_expr( + package, + assigner, + has_returned_var_id, + true_lit, + &Ty::Prim(Prim::Bool), + ); + let assign_flag_semi = alloc_semi_stmt(package, assigner, assign_flag, Span::default()); + + let flag_block = { + let stmts = vec![assign_val_semi, assign_flag_semi]; + let ty: &Ty = &Ty::UNIT; + alloc_block(package, assigner, stmts, ty.clone(), Span::default()) + }; + let flag_block_expr = { + let ty: &Ty = &Ty::UNIT; + alloc_block_expr(package, assigner, flag_block, ty.clone(), Span::default()) + }; + + // Replace the Return expression in-place with the block expression. 
+ let replacement = package.get_expr(flag_block_expr).clone(); + let e = package.exprs.get_mut(expr_id).expect("expr not found"); + *e = Expr { + id: expr_id, + span: expr.span, + ty: replacement.ty, + kind: replacement.kind, + exec_graph_range: EMPTY_EXEC_RANGE, + }; + } + ExprKind::Block(block_id) => { + let bid = *block_id; + replace_returns_in_block( + package, + assigner, + package_id, + bid, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + // Nested blocks that previously contained a trailing `Return` + // expression may have been typed to the callable return type. + // After replacement the Return is a Unit block, so sync the + // block's type to its trailing expression (Unit when the block + // has no trailing Expr stmt). Also refresh the enclosing + // expression's type since `Block` expressions carry the + // block's type on the `Expr` node. + sync_block_type_to_stmt_or_unit(package, bid); + let new_block_ty = package.get_block(bid).ty.clone(); + let e = package.exprs.get_mut(expr_id).expect("expr not found"); + e.ty = new_block_ty; + } + ExprKind::If(_, then_id, else_opt) => { + let then_id = *then_id; + let else_id = *else_opt; + replace_returns_in_expr( + package, + assigner, + package_id, + then_id, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + if let Some(e) = else_id { + replace_returns_in_expr( + package, + assigner, + package_id, + e, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + } + // Update the If expression type to reflect branch type changes. + // After return replacement, a branch containing Return is + // replaced with a Unit-typed flag-assignment block. Derive the + // If type from branch types: prefer the non-Unit branch type so + // the surrounding Local binding keeps its original type. 
+ let then_ty = package.get_expr(then_id).ty.clone(); + let new_ty = if let Some(else_id) = else_id { + let else_ty = package.get_expr(else_id).ty.clone(); + if then_ty == Ty::UNIT { + else_ty + } else { + then_ty + } + } else { + then_ty + }; + let e = package.exprs.get_mut(expr_id).expect("expr not found"); + e.ty = new_ty; + } + // Audit: Only `Block` and `If` arms above require + // post-recursion type synchronization (their enclosing `Expr` + // carries the inner expression's type, which shifts to `Unit` + // when a trailing `Return` is replaced). All remaining arms + // below are defensive — `Return` cannot legitimately nest in + // these positions in valid normalized FIR. Recursive replacement here + // keeps the pass robust by construction. + ExprKind::Array(exprs) | ExprKind::ArrayLit(exprs) | ExprKind::Tuple(exprs) => { + let ids: Vec = exprs.clone(); + for e in ids { + replace_returns_in_expr( + package, + assigner, + package_id, + e, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + } + } + ExprKind::ArrayRepeat(a, b) + | ExprKind::Assign(a, b) + | ExprKind::AssignOp(_, a, b) + | ExprKind::BinOp(_, a, b) + | ExprKind::Call(a, b) + | ExprKind::Index(a, b) + | ExprKind::AssignField(a, _, b) + | ExprKind::UpdateField(a, _, b) => { + let (a_id, b_id) = (*a, *b); + replace_returns_in_expr( + package, + assigner, + package_id, + a_id, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + replace_returns_in_expr( + package, + assigner, + package_id, + b_id, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + } + ExprKind::AssignIndex(a, b, c) | ExprKind::UpdateIndex(a, b, c) => { + let (a_id, b_id, c_id) = (*a, *b, *c); + replace_returns_in_expr( + package, + assigner, + package_id, + a_id, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + replace_returns_in_expr( + package, + assigner, + package_id, + b_id, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + replace_returns_in_expr( + package, + 
assigner, + package_id, + c_id, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + } + ExprKind::Fail(e) | ExprKind::Field(e, _) | ExprKind::UnOp(_, e) => { + let sub = *e; + replace_returns_in_expr( + package, + assigner, + package_id, + sub, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + } + ExprKind::Range(start, step, end) => { + let ids: Vec = [start, step, end].into_iter().flatten().copied().collect(); + for e in ids { + replace_returns_in_expr( + package, + assigner, + package_id, + e, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + } + } + ExprKind::Struct(_, copy, fields) => { + let copy_id = *copy; + let field_ids: Vec = fields.iter().map(|fa| fa.value).collect(); + if let Some(c) = copy_id { + replace_returns_in_expr( + package, + assigner, + package_id, + c, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + } + for e in field_ids { + replace_returns_in_expr( + package, + assigner, + package_id, + e, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + } + } + ExprKind::String(components) => { + let ids: Vec = components + .iter() + .filter_map(|c| match c { + qsc_fir::fir::StringComponent::Expr(e) => Some(*e), + qsc_fir::fir::StringComponent::Lit(_) => None, + }) + .collect(); + for e in ids { + replace_returns_in_expr( + package, + assigner, + package_id, + e, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + } + } + ExprKind::While(cond, body) => { + let (cond_id, body_id) = (*cond, *body); + if contains_return_in_block(package, body_id) + || contains_return_in_expr(package, cond_id) + { + // Delegate to `transform_while_in_expr` so the nested + // while's condition gets conjoined with `not __has_returned` + // and its body returns are rewritten, matching the + // top-level while handling. Without this, a nested while + // whose only exit is the return would loop forever after + // the return-to-flag rewrite. 
+ transform_while_in_expr( + package, + assigner, + package_id, + expr_id, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + } else { + // No returns reachable through this while; structural + // recursion into the condition is sufficient (the body is + // return-free so walking it is a no-op). + replace_returns_in_expr( + package, + assigner, + package_id, + cond_id, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + } + } + ExprKind::Closure(_, _) | ExprKind::Hole | ExprKind::Lit(_) | ExprKind::Var(_, _) => {} + } +} + +/// Rewrites return-bearing while-condition subexpressions to preserve +/// Bool-typed condition semantics under the flag strategy. +/// +/// A condition-side `Return(v)` becomes: +/// `{ __ret_val = v; __has_returned = true; false }` +/// so the loop condition evaluates to false immediately after capturing +/// the return value. +#[allow(clippy::too_many_arguments)] +#[allow(clippy::too_many_lines)] +fn replace_returns_in_condition_expr( + package: &mut Package, + assigner: &mut Assigner, + package_id: PackageId, + expr_id: ExprId, + has_returned_var_id: LocalVarId, + ret_val_var_id: LocalVarId, + udt_pure_tys: &UdtPureTyCache, +) { + let expr = package.get_expr(expr_id).clone(); + match &expr.kind { + ExprKind::Return(inner_id) => { + replace_condition_return_with_flags( + package, + assigner, + expr_id, + expr.span, + *inner_id, + has_returned_var_id, + ret_val_var_id, + ); + } + ExprKind::Block(block_id) => { + let bid = *block_id; + let stmts = package.get_block(bid).stmts.clone(); + let last_stmt = stmts.last().copied(); + + for stmt_id in stmts { + let expr_ids: Vec = { + let stmt = package.get_stmt(stmt_id); + match &stmt.kind { + StmtKind::Expr(e) | StmtKind::Semi(e) | StmtKind::Local(_, _, e) => { + vec![*e] + } + StmtKind::Item(_) => vec![], + } + }; + + for e in expr_ids { + if Some(stmt_id) == last_stmt + && matches!(package.get_stmt(stmt_id).kind, StmtKind::Expr(_)) + { + replace_returns_in_condition_expr( 
+ package, + assigner, + package_id, + e, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + } else { + replace_returns_in_expr( + package, + assigner, + package_id, + e, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + } + } + } + + sync_block_type_to_stmt_or_unit(package, bid); + let new_block_ty = package.get_block(bid).ty.clone(); + let e = package.exprs.get_mut(expr_id).expect("expr not found"); + e.ty = new_block_ty; + } + ExprKind::If(cond_id, then_id, else_opt) => { + replace_returns_in_condition_expr( + package, + assigner, + package_id, + *cond_id, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + replace_returns_in_condition_expr( + package, + assigner, + package_id, + *then_id, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + if let Some(e) = else_opt { + replace_returns_in_condition_expr( + package, + assigner, + package_id, + *e, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + } + } + ExprKind::BinOp(BinOp::AndL | BinOp::OrL, lhs, rhs) => { + replace_returns_in_condition_expr( + package, + assigner, + package_id, + *lhs, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + replace_returns_in_condition_expr( + package, + assigner, + package_id, + *rhs, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + } + ExprKind::UnOp(UnOp::NotL, inner_id) => { + replace_returns_in_condition_expr( + package, + assigner, + package_id, + *inner_id, + has_returned_var_id, + ret_val_var_id, + udt_pure_tys, + ); + } + _ => { + assert!( + !contains_return_in_expr(package, expr_id), + "unexpected return-bearing while-condition shape after normalize" + ); + } + } +} + +/// Rewrites a `Return(inner_id)` that appears inside a condition expression into +/// a block that records the return and yields `false`. +/// +/// Before, evaluating the condition exits the callable directly. 
After, the +/// enclosing expression tree stays well-typed as `Bool`, but the block stores +/// the return value in `ret_val_var_id`, sets `has_returned_var_id`, and leaves +/// later guards to skip the rest of the computation. +fn replace_condition_return_with_flags( + package: &mut Package, + assigner: &mut Assigner, + return_expr_id: ExprId, + span: Span, + inner_id: ExprId, + has_returned_var_id: LocalVarId, + ret_val_var_id: LocalVarId, +) { + let inner_ty = package.get_expr(inner_id).ty.clone(); + let assign_val = create_assign_expr(package, assigner, ret_val_var_id, inner_id, &inner_ty); + let assign_val_semi = alloc_semi_stmt(package, assigner, assign_val, Span::default()); + + let true_lit = alloc_bool_lit(package, assigner, true, Span::default()); + let assign_flag = create_assign_expr( + package, + assigner, + has_returned_var_id, + true_lit, + &Ty::Prim(Prim::Bool), + ); + let assign_flag_semi = alloc_semi_stmt(package, assigner, assign_flag, Span::default()); + + // Condition contexts still need a boolean value after the return is lowered + // into side-effecting flag writes. + let false_lit = alloc_bool_lit(package, assigner, false, Span::default()); + let false_stmt = alloc_expr_stmt(package, assigner, false_lit, Span::default()); + + let flag_block = { + let stmts = vec![assign_val_semi, assign_flag_semi, false_stmt]; + let ty: &Ty = &Ty::Prim(Prim::Bool); + alloc_block(package, assigner, stmts, ty.clone(), Span::default()) + }; + let flag_block_expr = { + let ty: &Ty = &Ty::Prim(Prim::Bool); + alloc_block_expr(package, assigner, flag_block, ty.clone(), Span::default()) + }; + + let replacement = package.get_expr(flag_block_expr).clone(); + let e = package + .exprs + .get_mut(return_expr_id) + .expect("expr not found"); + *e = Expr { + id: return_expr_id, + span, + ty: replacement.ty, + kind: replacement.kind, + exec_graph_range: EMPTY_EXEC_RANGE, + }; +} + +/// Wrap a statement so it is skipped when `__has_returned` is already set. 
+/// +/// # Before +/// ```text +/// stmt; +/// ``` +/// # After (Semi / Item / Unit-typed Expr statements) +/// ```text +/// if not __has_returned { stmt }; +/// ``` +/// # After (Local statements) +/// ```text +/// // `let x : T = init;` becomes: +/// let x : T = if not __has_returned { init } else { default(T) }; +/// ``` +/// For `Local` statements, wrapping the whole statement in an `if` block +/// would scope the binding away from subsequent statements that reference +/// it (a real bug if an earlier rewrite lifts a trailing value into a +/// `let @generated_ident_N = ...` that is then referenced from the +/// block's final `if __has_returned { __ret_val } else { @generated_ident_N }` +/// expression). Instead, the initializer is rewritten to a conditional +/// expression and the `Local` statement itself stays at the outer scope, +/// preserving the binding's visibility. +/// +/// # Requires +/// - `stmt_id` is valid in `package`. +/// - `has_returned_var_id` is the flag introduced by +/// [`transform_block_with_flags`]. +/// - For `Local` statements, the initializer's type has a classical +/// default reachable through [`create_default_value`]. +/// +/// # Ensures +/// - For non-`Local` statements, returns a new `Semi` statement whose +/// expression guards execution of `stmt_id` on `not __has_returned`. +/// - For `Local` statements, mutates the statement's initializer in place +/// and returns the original `stmt_id` unchanged. +/// - The original statement's effects execute only when no prior +/// flag-based return has fired. +/// +/// # Mutations +/// - Allocates the guard block, `Var`/`Not` expressions, `If` expression, +/// and wrapping `Semi` statement through `assigner`. +/// - For `Local` statements, rewrites `package.stmts[stmt_id].kind` in +/// place to reference a new guarded-initializer `ExprId`. 
+fn guard_stmt_with_flag( + package: &mut Package, + assigner: &mut Assigner, + package_id: PackageId, + stmt_id: StmtId, + has_returned_var_id: LocalVarId, + udt_pure_tys: &UdtPureTyCache, +) -> StmtId { + // `Local` statements require special handling: wrapping the whole + // statement in `if not __has_returned { let x = init; }` would hide + // `x` from subsequent statements that reference it. Instead, rewrite + // the initializer to `if not __has_returned { init } else { default }` + // and leave the `Local` at the outer scope. + if let StmtKind::Local(mutability, pat_id, init_expr_id) = package.get_stmt(stmt_id).kind { + let init_ty = package.get_expr(init_expr_id).ty.clone(); + let default_val = require_classical_default( + package, + assigner, + package_id, + &init_ty, + udt_pure_tys, + UnsupportedDefaultSite::GuardedLocalInitializer, + ); + + let not_flag = create_not_var_expr(package, assigner, has_returned_var_id); + + let then_trailing = alloc_expr_stmt(package, assigner, init_expr_id, Span::default()); + let then_block = { + let stmts = vec![then_trailing]; + let ty: &Ty = &init_ty; + alloc_block(package, assigner, stmts, ty.clone(), Span::default()) + }; + let then_expr = { + let ty: &Ty = &init_ty; + alloc_block_expr(package, assigner, then_block, ty.clone(), Span::default()) + }; + + let else_trailing = alloc_expr_stmt(package, assigner, default_val, Span::default()); + let else_block = { + let stmts = vec![else_trailing]; + let ty: &Ty = &init_ty; + alloc_block(package, assigner, stmts, ty.clone(), Span::default()) + }; + let else_expr = { + let ty: &Ty = &init_ty; + alloc_block_expr(package, assigner, else_block, ty.clone(), Span::default()) + }; + + let if_expr = { + let else_expr = Some(else_expr); + let ty: &Ty = &init_ty; + alloc_if_expr( + package, + assigner, + not_flag, + then_expr, + else_expr, + ty.clone(), + Span::default(), + ) + }; + + let stmt = package.stmts.get_mut(stmt_id).expect("stmt not found"); + stmt.kind = 
StmtKind::Local(mutability, pat_id, if_expr);
+        return stmt_id;
+    }
+
+    assert!(
+        match &package.get_stmt(stmt_id).kind {
+            StmtKind::Semi(_) | StmtKind::Item(_) => true,
+            StmtKind::Expr(e) => package.get_expr(*e).ty == Ty::UNIT,
+            StmtKind::Local(_, _, _) => unreachable!("Local handled above"),
+        },
+        "guard_stmt_with_flag requires Unit-typed inner stmt"
+    );
+    let not_flag = create_not_var_expr(package, assigner, has_returned_var_id);
+    let guard_block = {
+        let stmts = vec![stmt_id];
+        let ty: &Ty = &Ty::UNIT;
+        alloc_block(package, assigner, stmts, ty.clone(), Span::default())
+    };
+    let guard_block_expr = {
+        let ty: &Ty = &Ty::UNIT;
+        alloc_block_expr(package, assigner, guard_block, ty.clone(), Span::default())
+    };
+    let if_expr = {
+        let ty: &Ty = &Ty::UNIT;
+        alloc_if_expr(
+            package,
+            assigner,
+            not_flag,
+            guard_block_expr,
+            None,
+            ty.clone(),
+            Span::default(),
+        )
+    };
+    alloc_semi_stmt(package, assigner, if_expr, Span::default())
+}
+
+/// Synthesize the trailing expression that finalizes the flag-based
+/// transform, using `__has_returned` to select between the captured return
+/// value and the block's original trailing value.
+///
+/// When the last statement in `stmts` is a trailing expression (`Expr`,
+/// not `Semi`), it is popped and reused as the else branch. Otherwise the
+/// else branch is `()` (the return type must be Unit in that case).
+///
+/// The trailing expression is first bound to a local variable
+/// (`__trailing_result`) before the flag check, ensuring that any flag
+/// assignments inside the trailing expression evaluate before the
+/// `__has_returned` condition is tested. This prevents a temporal-ordering
+/// violation.
+///
+/// ```text
+/// // stmts ends in: ...; original_trailing
+/// // Result appended:
+/// let __trailing_result : T = original_trailing;
+/// if __has_returned { __ret_val } else { __trailing_result }
+///
+/// // stmts ends in: ...; side_effect;
+/// // Result appended:
+/// if __has_returned { __ret_val } else { () }
+/// ```
+fn create_flag_trailing_expr(
+    package: &mut Package,
+    assigner: &mut Assigner,
+    stmts: &mut Vec<StmtId>,
+    has_returned_var_id: LocalVarId,
+    ret_val_var_id: LocalVarId,
+    return_ty: &Ty,
+) -> Option<StmtId> {
+    // Check if the last statement is a value-producing trailing expression
+    // for this callable, not just any expression statement. The flag rewrite
+    // can turn all-returning non-Unit blocks into Unit expression statements;
+    // those must not be rebound as `__trailing_result : T`.
+    let trailing_expr = stmts.last().and_then(|&stmt_id| {
+        if let StmtKind::Expr(expr_id) = package.get_stmt(stmt_id).kind
+            && package.get_expr(expr_id).ty == *return_ty
+        {
+            Some(expr_id)
+        } else {
+            None
+        }
+    });
+
+    let flag_var = {
+        let ty: &Ty = &Ty::Prim(Prim::Bool);
+        alloc_local_var_expr(
+            package,
+            assigner,
+            has_returned_var_id,
+            ty.clone(),
+            Span::default(),
+        )
+    };
+    let ret_var = alloc_local_var_expr(
+        package,
+        assigner,
+        ret_val_var_id,
+        return_ty.clone(),
+        Span::default(),
+    );
+
+    if let Some(original_trailing) = trailing_expr {
+        // Pop the trailing expr and bind it to a local before the flag check.
+        // This ensures that any flag assignments inside the trailing expression
+        // evaluate before the `__has_returned` condition is tested.
+ stmts.pop().expect("stmts should not be empty"); + + // let __trailing_result : T = original_trailing; + let (trailing_var_id, trailing_decl_stmt) = { + let mutability = Mutability::Immutable; + alloc_local_var( + package, + assigner, + "__trailing_result", + return_ty, + original_trailing, + mutability, + ) + }; + stmts.push(trailing_decl_stmt); + + // if __has_returned { __ret_val } else { __trailing_result } + let trailing_var_expr = alloc_local_var_expr( + package, + assigner, + trailing_var_id, + return_ty.clone(), + Span::default(), + ); + let if_expr = { + let else_expr = Some(trailing_var_expr); + alloc_if_expr( + package, + assigner, + flag_var, + ret_var, + else_expr, + return_ty.clone(), + Span::default(), + ) + }; + Some(alloc_expr_stmt(package, assigner, if_expr, Span::default())) + } else { + // No fallthrough value survives. Unit returns can keep the previous + // explicit `()` fallback. For non-Unit returns, use the initialized + // return slot on the unreachable false branch to keep the FIR typed. + let fallback_expr = if return_ty == &Ty::UNIT { + alloc_unit_expr(package, assigner, Span::default()) + } else { + alloc_local_var_expr( + package, + assigner, + ret_val_var_id, + return_ty.clone(), + Span::default(), + ) + }; + let if_expr = { + let else_expr = Some(fallback_expr); + alloc_if_expr( + package, + assigner, + flag_var, + ret_var, + else_expr, + return_ty.clone(), + Span::default(), + ) + }; + Some(alloc_expr_stmt(package, assigner, if_expr, Span::default())) + } +} + +/// Check whether a type has a synthesizable classical default value without +/// allocating any FIR nodes. +/// +/// Returns `true` for types that [`create_default_value`] would succeed on, +/// `false` for types (like `Qubit`) that have no classical default. Used by +/// [`unify_returns`] to emit a user-facing error before entering the flag +/// strategy, avoiding a panic in [`require_classical_default`]. 
+fn can_create_classical_default(ty: &Ty, udt_pure_tys: &UdtPureTyCache) -> bool { + match ty { + Ty::Prim( + Prim::Bool + | Prim::Int + | Prim::BigInt + | Prim::Double + | Prim::Pauli + | Prim::Result + | Prim::String + | Prim::Range + | Prim::RangeFrom + | Prim::RangeTo + | Prim::RangeFull, + ) + | Ty::Array(_) => true, + Ty::Tuple(elems) => elems + .iter() + .all(|e| can_create_classical_default(e, udt_pure_tys)), + Ty::Udt(Res::Item(item_id)) => udt_pure_tys + .get(&(item_id.package, item_id.item)) + .is_some_and(|pure_ty| can_create_classical_default(pure_ty, udt_pure_tys)), + Ty::Arrow(arrow) => { + can_create_classical_default(&arrow.output, udt_pure_tys) + && matches!(arrow.functors, qsc_fir::ty::FunctorSet::Value(_)) + } + Ty::Infer(_) | Ty::Param(_) | Ty::Err | Ty::Prim(Prim::Qubit) | Ty::Udt(_) => false, + } +} + +#[derive(Clone, Copy, Debug)] +enum UnsupportedDefaultSite { + ReturnSlot, + GuardedLocalInitializer, +} + +impl UnsupportedDefaultSite { + fn description(self) -> &'static str { + match self { + Self::ReturnSlot => "flag-strategy return-slot (__ret_val) initialization", + Self::GuardedLocalInitializer => "flag-strategy guarded Local initializer", + } + } +} + +/// Enforces the unsupported-default policy for flag-strategy synthesis sites. +/// +/// The `create_default_value*` helpers intentionally return `Option` so callers +/// can decide policy. For return unification's flag strategy, missing defaults +/// are an internal compiler-contract violation and must fail loudly with a +/// stable, site-specific panic message. +/// +/// **Note:** the known user-reachable case (Qubit-return-in-loop) is now +/// caught earlier by [`can_create_classical_default`] in [`unify_returns`], +/// which emits a user-facing [`Error::UnsupportedLoopReturnType`] diagnostic. +/// This panic remains as a safety net for unforeseen cases. 
+fn require_classical_default( + package: &mut Package, + assigner: &mut Assigner, + package_id: PackageId, + ty: &Ty, + udt_pure_tys: &UdtPureTyCache, + site: UnsupportedDefaultSite, +) -> ExprId { + create_default_value(package, assigner, package_id, ty, udt_pure_tys).unwrap_or_else(|| { + panic!( + "return_unify unsupported-default contract violation: {} requires a classical default, but `{ty}` has none", + site.description(), + ) + }) +} + +/// Create a default value expression for a type, used to initialize `__ret_val`. +/// +/// The value is never observed: any read of `__ret_val` is guarded by +/// `__has_returned`, which becomes `true` only after an explicit return has +/// written `__ret_val`. Only the type must match. +/// +/// # Before +/// ```text +/// (no expression) +/// ``` +/// # After +/// ```text +/// Expr { ty, kind: default(ty) } // e.g. Lit(Int(0)), Tuple(()), Array([]) +/// ``` +/// # Requires +/// - `ty` has a synthesizable classical default (see +/// [`create_default_value_kind`]); otherwise this returns `None`. +/// +/// # Ensures +/// - Returns `Some(expr_id)` whose `Expr.ty == ty.clone()`. +/// - Returns `None` when no classical default exists (caller surfaces as a +/// deterministic diagnostic rather than emitting malformed FIR). +/// +/// # Mutations +/// - Allocates one fresh `Expr` through `assigner` when `Some`. +fn create_default_value( + package: &mut Package, + assigner: &mut Assigner, + package_id: PackageId, + ty: &Ty, + udt_pure_tys: &UdtPureTyCache, +) -> Option { + let kind = create_default_value_kind(package, assigner, package_id, ty, udt_pure_tys)?; + + let expr_id = assigner.next_expr(); + package.exprs.insert( + expr_id, + Expr { + id: expr_id, + span: Span::default(), + ty: ty.clone(), + kind, + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + Some(expr_id) +} + +/// Build a well-typed FIR `ExprKind` for the zero value of `ty`. 
+/// +/// # Before +/// ```text +/// Ty::{Prim, Tuple, Array, Udt(Res::Item(..)), Arrow(..)} +/// ``` +/// # After +/// ```text +/// ExprKind::{Lit(..), Tuple([defaults..]), Array([]), Var(Res::Item(nop), []), ...} +/// ``` +/// # Requires +/// - `ty` is a type reachable from the callable's return type. +/// - `udt_pure_tys` has been populated from the store. +/// - `package_id` is the id of the package owning `package` — the synthesized +/// nop callable for arrow types is inserted there and referenced through it. +/// +/// # Ensures +/// - Returns `None` when the type has no synthesizable classical default: +/// unresolved types (`Ty::Infer`, `Ty::Param`, `Ty::Err`), qubits +/// (`Prim::Qubit`), UDTs whose pure-ty cache entry is +/// missing or unresolved, and arrow types whose output type itself has no +/// default. +/// - Returns `Some(kind)` whose zero value matches `ty` structurally. +/// - For `Ty::Arrow`, `Some(Var(Res::Item(nop_item), vec![]))` references a +/// newly synthesized nop callable of the same arrow signature; the nop's +/// body returns the output type's default. Any later `Call` on the +/// resulting `__ret_val` value resolves to that nop — correct behavior +/// because the flag guard ensures reads only occur when an explicit return +/// already overwrote `__ret_val` with the real callable. +/// +/// # Mutations +/// - For `Ty::Tuple` and `Ty::Udt` composites, allocates nested default +/// `Expr` nodes through `assigner` via [`create_default_value`]. +/// - For `Ty::Arrow`, inserts a new `Item` (callable) into `package.items` +/// and allocates its body `Pat`, `Block`, and trailing `Expr` / `Stmt`. 
+fn create_default_value_kind( + package: &mut Package, + assigner: &mut Assigner, + package_id: PackageId, + ty: &Ty, + udt_pure_tys: &UdtPureTyCache, +) -> Option { + match ty { + Ty::Prim(Prim::Bool) => Some(ExprKind::Lit(Lit::Bool(false))), + Ty::Prim(Prim::Int) => Some(ExprKind::Lit(Lit::Int(0))), + Ty::Prim(Prim::BigInt) => Some(ExprKind::Lit(Lit::BigInt(BigInt::from(0)))), + Ty::Prim(Prim::Double) => Some(ExprKind::Lit(Lit::Double(0.0))), + Ty::Prim(Prim::Pauli) => Some(ExprKind::Lit(Lit::Pauli(qsc_fir::fir::Pauli::I))), + Ty::Prim(Prim::Result) => Some(ExprKind::Lit(Lit::Result(Result::Zero))), + Ty::Prim(Prim::String) => Some(ExprKind::String(Vec::new())), + Ty::Tuple(elems) if elems.is_empty() => Some(ExprKind::Tuple(Vec::new())), + Ty::Tuple(elems) => { + let elem_exprs: Vec = elems + .iter() + .map(|elem_ty| { + create_default_value(package, assigner, package_id, elem_ty, udt_pure_tys) + }) + .collect::>()?; + Some(ExprKind::Tuple(elem_exprs)) + } + Ty::Array(_) => Some(ExprKind::Array(Vec::new())), + Ty::Udt(Res::Item(item_id)) => { + let pure_ty = udt_pure_tys.get(&(item_id.package, item_id.item))?.clone(); + create_default_value_kind(package, assigner, package_id, &pure_ty, udt_pure_tys) + } + Ty::Arrow(arrow) => { + create_nop_callable_var(package, assigner, package_id, arrow, udt_pure_tys) + } + Ty::Prim(Prim::Range | Prim::RangeFrom | Prim::RangeTo | Prim::RangeFull) => { + Some(ExprKind::Range(None, None, None)) + } + // No well-typed classical default: unresolved/placeholder types, + // qubits and unresolved UDTs. + Ty::Infer(_) | Ty::Param(_) | Ty::Err | Ty::Prim(Prim::Qubit) | Ty::Udt(_) => None, + } +} + +/// Synthesize a nop callable matching `arrow`'s signature, insert it into +/// `package`, and return a `Var(Res::Item(..))` expression referring to it. +/// +/// The synthesized callable's body is a single-statement block whose +/// trailing expression is the default of the arrow's output type. The input +/// pattern is a typed `Discard`. 
If the output type itself has no classical +/// default the synthesis is abandoned and `None` is propagated. +/// +/// # Ensures +/// - Returns `Some(Var(Res::Item(ItemId { package: package_id, item: new_item_id }), vec![]))`. +/// - Inserts exactly one new `Item` of kind `Callable` into `package.items`. +/// - The new callable's arrow scheme (input/output/kind/functors) matches +/// `arrow`. +fn create_nop_callable_var( + package: &mut Package, + assigner: &mut Assigner, + package_id: PackageId, + arrow: &qsc_fir::ty::Arrow, + udt_pure_tys: &UdtPureTyCache, +) -> Option { + // Build the nop body's default-of-output trailing expression. + let output_default = + create_default_value(package, assigner, package_id, &arrow.output, udt_pure_tys)?; + let trailing_stmt = alloc_expr_stmt(package, assigner, output_default, Span::default()); + let body_block = { + let stmts = vec![trailing_stmt]; + let ty: &Ty = &arrow.output; + alloc_block(package, assigner, stmts, ty.clone(), Span::default()) + }; + + // Input pattern: a typed Discard matching the arrow's input type. + let input_pat_id = assigner.next_pat(); + package.pats.insert( + input_pat_id, + Pat { + id: input_pat_id, + span: Span::default(), + ty: *arrow.input.clone(), + kind: PatKind::Discard, + }, + ); + + let body_spec = qsc_fir::fir::SpecDecl { + id: assigner.next_node(), + span: Span::default(), + block: body_block, + input: None, + exec_graph: qsc_fir::fir::ExecGraph::default(), + }; + let body_impl = qsc_fir::fir::SpecImpl { + body: body_spec, + adj: None, + ctl: None, + ctl_adj: None, + }; + + // After monomorphization, non-Value functors should not appear in + // reachable return types; surface this as a missing default rather + // than a panic so the pass bails deterministically. 
+ let qsc_fir::ty::FunctorSet::Value(functors) = arrow.functors else { + return None; + }; + + let new_item_id = assigner.next_item(); + let callable_name: Rc = Rc::from(format!("__return_unify_nop_{new_item_id}")); + let decl = CallableDecl { + id: assigner.next_node(), + span: Span::default(), + kind: arrow.kind, + name: Ident { + id: LocalVarId::from(0_u32), + span: Span::default(), + name: callable_name, + }, + generics: Vec::new(), + input: input_pat_id, + output: *arrow.output.clone(), + functors, + implementation: CallableImpl::Spec(body_impl), + attrs: Vec::new(), + }; + + let item = qsc_fir::fir::Item { + id: new_item_id, + span: Span::default(), + parent: None, + doc: Rc::from(""), + attrs: Vec::new(), + visibility: qsc_fir::fir::Visibility::Internal, + kind: ItemKind::Callable(Box::new(decl)), + }; + package.items.insert(new_item_id, item); + + Some(ExprKind::Var( + Res::Item(qsc_fir::fir::ItemId { + package: package_id, + item: new_item_id, + }), + Vec::new(), + )) +} + +/// Create `not Var(__has_returned)`. +fn create_not_var_expr( + package: &mut Package, + assigner: &mut Assigner, + var_id: LocalVarId, +) -> ExprId { + let var = { + let ty: &Ty = &Ty::Prim(Prim::Bool); + alloc_local_var_expr(package, assigner, var_id, ty.clone(), Span::default()) + }; + alloc_not_expr(package, assigner, var, Span::default()) +} + +/// Create `Assign(Var(var_id), value)`. +fn create_assign_expr( + package: &mut Package, + assigner: &mut Assigner, + var_id: LocalVarId, + value: ExprId, + ty: &Ty, +) -> ExprId { + let var_expr = alloc_local_var_expr(package, assigner, var_id, ty.clone(), Span::default()); + alloc_assign_expr(package, assigner, var_expr, value, Span::default()) +} + +/// Create a mutable boolean variable declaration: `mutable name = value`. +/// Returns `(LocalVarId, StmtId)`. 
+fn create_mutable_bool_var( + package: &mut Package, + assigner: &mut Assigner, + name: &str, + value: bool, +) -> (LocalVarId, StmtId) { + let init_expr = alloc_bool_lit(package, assigner, value, Span::default()); + { + let ty: &Ty = &Ty::Prim(Prim::Bool); + { + let mutability = Mutability::Mutable; + alloc_local_var(package, assigner, name, ty, init_expr, mutability) + } + } +} diff --git a/source/compiler/qsc_fir_transforms/src/return_unify/detect.rs b/source/compiler/qsc_fir_transforms/src/return_unify/detect.rs new file mode 100644 index 0000000000..12516feffa --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/return_unify/detect.rs @@ -0,0 +1,224 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Exhaustive `Return` detection for the return-unification pass. +//! +//! Mirrors the exhaustive `ExprKind` walker in +//! [`crate::walk_utils`]: every variant is matched explicitly, with no +//! wildcard arm, so adding a new FIR `ExprKind` variant produces a compile +//! error here rather than silently missing a detection site. +//! +//! `ExprKind::Closure` is treated as a leaf: closure captures are +//! [`qsc_fir::fir::LocalVarId`]s rather than expressions, and the closure +//! body lives in a separate callable that `return_unify` visits +//! independently. + +use qsc_fir::fir::{BlockId, ExprId, ExprKind, PackageLookup, StmtId, StmtKind, StringComponent}; + +/// Returns `true` when `block_id` contains any `ExprKind::Return` at any depth. +pub(super) fn contains_return_in_block(lookup: &impl PackageLookup, block_id: BlockId) -> bool { + let block = lookup.get_block(block_id); + block + .stmts + .iter() + .any(|&stmt_id| contains_return_in_stmt(lookup, stmt_id)) +} + +/// Returns `true` when the statement's initializer/expression contains any +/// `ExprKind::Return`. 
+pub(super) fn contains_return_in_stmt(lookup: &impl PackageLookup, stmt_id: StmtId) -> bool { + let stmt = lookup.get_stmt(stmt_id); + match &stmt.kind { + StmtKind::Expr(expr_id) | StmtKind::Semi(expr_id) => { + contains_return_in_expr(lookup, *expr_id) + } + StmtKind::Local(_, _, expr_id) => contains_return_in_expr(lookup, *expr_id), + StmtKind::Item(_) => false, + } +} + +/// Return `true` when any sub-expression of `expr_id` is an `ExprKind::Return`. +/// +/// # Before +/// ```text +/// expr tree rooted at expr_id +/// ``` +/// # After +/// ```text +/// unchanged +/// ``` +/// # Requires +/// - `expr_id` is valid in `lookup`. +/// +/// # Ensures +/// - Returns `true` iff `ExprKind::Return(_)` appears at any depth outside +/// closure boundaries. +/// - Does not recurse into `ExprKind::Closure`: captures are +/// [`qsc_fir::fir::LocalVarId`]s, not sub-expressions, and the closure +/// body lives in a separate callable. +/// +/// # Mutations +/// - None (read-only). +pub(super) fn contains_return_in_expr(lookup: &impl PackageLookup, expr_id: ExprId) -> bool { + let expr = lookup.get_expr(expr_id); + match &expr.kind { + ExprKind::Return(_) => true, + ExprKind::Array(exprs) | ExprKind::ArrayLit(exprs) | ExprKind::Tuple(exprs) => { + exprs.iter().any(|&e| contains_return_in_expr(lookup, e)) + } + ExprKind::ArrayRepeat(a, b) + | ExprKind::Assign(a, b) + | ExprKind::AssignOp(_, a, b) + | ExprKind::BinOp(_, a, b) + | ExprKind::Call(a, b) + | ExprKind::Index(a, b) + | ExprKind::AssignField(a, _, b) + | ExprKind::UpdateField(a, _, b) => { + contains_return_in_expr(lookup, *a) || contains_return_in_expr(lookup, *b) + } + ExprKind::AssignIndex(a, b, c) | ExprKind::UpdateIndex(a, b, c) => { + contains_return_in_expr(lookup, *a) + || contains_return_in_expr(lookup, *b) + || contains_return_in_expr(lookup, *c) + } + ExprKind::Block(block_id) => contains_return_in_block(lookup, *block_id), + ExprKind::Closure(_, _) | ExprKind::Hole | ExprKind::Lit(_) | ExprKind::Var(_, _) => 
false, + ExprKind::Fail(e) | ExprKind::Field(e, _) | ExprKind::UnOp(_, e) => { + contains_return_in_expr(lookup, *e) + } + ExprKind::If(cond, body, otherwise) => { + contains_return_in_expr(lookup, *cond) + || contains_return_in_expr(lookup, *body) + || otherwise.is_some_and(|e| contains_return_in_expr(lookup, e)) + } + ExprKind::Range(start, step, end) => [start, step, end] + .into_iter() + .flatten() + .any(|&e| contains_return_in_expr(lookup, e)), + ExprKind::Struct(_, copy, fields) => { + copy.is_some_and(|c| contains_return_in_expr(lookup, c)) + || fields + .iter() + .any(|fa| contains_return_in_expr(lookup, fa.value)) + } + ExprKind::String(components) => components.iter().any(|c| match c { + StringComponent::Expr(e) => contains_return_in_expr(lookup, *e), + StringComponent::Lit(_) => false, + }), + ExprKind::While(cond, body) => { + contains_return_in_expr(lookup, *cond) || contains_return_in_block(lookup, *body) + } + } +} + +#[cfg(test)] +mod tests { + use super::{contains_return_in_block, contains_return_in_expr, contains_return_in_stmt}; + use crate::test_utils::{PipelineStage, compile_and_run_pipeline_to}; + use indoc::indoc; + use qsc_fir::fir::{ + BlockId, CallableImpl, ExprKind, ItemKind, Package, PackageLookup, StmtKind, + }; + + fn find_body_block_id(package: &Package, callable_name: &str) -> BlockId { + let decl = package + .items + .values() + .find_map(|item| match &item.kind { + ItemKind::Callable(decl) if decl.name.name.as_ref() == callable_name => Some(decl), + _ => None, + }) + .unwrap_or_else(|| panic!("callable '{callable_name}' not found")); + + let CallableImpl::Spec(spec_impl) = &decl.implementation else { + panic!("callable '{callable_name}' should have a body") + }; + + spec_impl.body.block + } + + #[test] + fn contains_return_in_stmt_detects_local_initializer_return() { + let source = indoc! 
{r#" + namespace Test { + function Main() : Int { + let x = if true { + return 1; + } else { + 0 + }; + x + } + } + "#}; + + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Mono); + let package = store.get(pkg_id); + let main_block_id = find_body_block_id(package, "Main"); + let main_block = package.get_block(main_block_id); + + let local_stmt_id = main_block + .stmts + .iter() + .copied() + .find(|stmt_id| matches!(package.get_stmt(*stmt_id).kind, StmtKind::Local(_, _, _))) + .expect("expected Main body to contain a Local initializer statement"); + + assert!( + contains_return_in_stmt(package, local_stmt_id), + "Local initializer with a return-bearing if-expression should be detected" + ); + assert!( + contains_return_in_block(package, main_block_id), + "Main block should report a reachable return through the Local initializer" + ); + } + + #[test] + fn contains_return_in_expr_does_not_descend_into_closure_body() { + let source = indoc! {r#" + namespace Test { + function Add(a : Int, b : Int) : Int { + if a == 0 { + return b; + } + a + b + } + + function Main() : Int { + let f = x -> Add(x, 1); + f(2) + } + } + "#}; + + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Mono); + let package = store.get(pkg_id); + + let main_block_id = find_body_block_id(package, "Main"); + let main_block = package.get_block(main_block_id); + let closure_expr_id = main_block + .stmts + .iter() + .find_map(|stmt_id| match package.get_stmt(*stmt_id).kind { + StmtKind::Local(_, _, init_expr_id) + if matches!(package.get_expr(init_expr_id).kind, ExprKind::Closure(_, _)) => + { + Some(init_expr_id) + } + _ => None, + }) + .expect("expected Main body to contain a closure initializer"); + + assert!( + !contains_return_in_expr(package, closure_expr_id), + "closure expressions should be treated as leaves by return detection" + ); + + let add_block_id = find_body_block_id(package, "Add"); + assert!( + contains_return_in_block(package, 
add_block_id), + "sanity check: Add should still contain a return before return_unify" + ); + } +} diff --git a/source/compiler/qsc_fir_transforms/src/return_unify/normalize.rs b/source/compiler/qsc_fir_transforms/src/return_unify/normalize.rs new file mode 100644 index 0000000000..df732c3b23 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/return_unify/normalize.rs @@ -0,0 +1,665 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Hoist-returns pre-pass for the return-unification pass. +//! +//! Rewrites every callable-body block so that any `ExprKind::Return` +//! surviving in a compound (non-statement-carrying) position is lifted to a +//! bare `return v;` statement at the enclosing statement boundary. After +//! this pass, `Return` only appears as: +//! +//! * a `StmtKind::Semi`/`StmtKind::Expr` whose expression is `ExprKind::Return(_)`, +//! * the trailing expression of a block reached through `ExprKind::Block`, +//! * a branch of `ExprKind::If`, or +//! * the body of `ExprKind::While`. +//! +//! The downstream strategy pass (`transform_block_if_else` / +//! `transform_block_with_flags`) consumes that statement-level shape. +//! +//! ## Match exhaustiveness +//! +//! [`hoist_in_expr`] is an exhaustive match over every `ExprKind` variant +//! — no wildcard arm — so introducing a new variant forces a compile error +//! here and at [`super::detect::contains_return_in_expr`]. +//! +//! ## Short-circuit special cases +//! +//! The logical `and` / `or` operators evaluate their right-hand side +//! conditionally. A Return in the RHS is handled by rewriting the `BinOp` +//! in place to an equivalent `if` that the strategy pass consumes: +//! +//! ```text +//! a and (return v) → if a { return v } else { false } +//! a or (return v) → if a { true } else { return v } +//! ``` +//! +//! A Return in the LHS evaluates unconditionally and is hoisted without a +//! guard. +//! +//! ## If / While condition returns +//! +//! 
A Return in the *condition* of an `If` or `While` fires before either +//! branch / the loop body ever runs. +//! +//! * For `If`, the hoist rewrites the expression in place to a `Block` +//! whose statements are the hoisted condition (ending in +//! `Semi(Return(v))`) plus a trailing default value of the original `If` +//! type, preserving the enclosing block-tail invariant. +//! * For `While`, the hoist lifts condition returns directly to statement +//! boundary (same as other compounds) so downstream rewriting preserves +//! callable-level early-exit semantics. + +#[cfg(test)] +mod tests; + +#[cfg(test)] +mod shape_tests; + +use qsc_fir::{ + assigner::Assigner, + fir::{ + BinOp, Expr, ExprId, ExprKind, Ident, Mutability, Package, PackageId, PackageLookup, Pat, + PatId, PatKind, Res, Stmt, StmtId, StmtKind, StringComponent, + }, + ty::Ty, +}; + +use crate::{ + EMPTY_EXEC_RANGE, + fir_builder::{alloc_block, alloc_bool_lit, alloc_expr_stmt, alloc_semi_stmt}, +}; +use qsc_data_structures::span::Span; +use std::rc::Rc; + +use super::detect::contains_return_in_expr; + +/// Iteration bound for the fixpoint loop — each pass removes at least one +/// compound-position `Return`, so the total expression count dominates. +fn fixpoint_bound(package: &Package) -> usize { + package.exprs.iter().count() + package.stmts.iter().count() + 1 +} + +/// Hoist every compound-position `Return` to its enclosing statement boundary. +/// +/// Runs to fixpoint across `block_id` and all transitively reachable +/// sub-blocks. +/// +/// # Before +/// ```text +/// { let x = f(return v); rest } +/// ``` +/// # After +/// ```text +/// { let _ = f; return v; } // compound-position Return lifted to Semi +/// ``` +/// # Requires +/// - `block_id` is a valid block in `package`. +/// +/// # Ensures +/// - Returns `true` iff any statement was rewritten. 
+/// - No `ExprKind::Return` remains in a compound (non-statement-carrying) +/// position; surviving Returns sit only at statement boundaries or inside +/// `Block`/`If`/`While` (which the strategy pass handles). +/// - Panics when the fixpoint bound is exceeded. +/// +/// # Mutations +/// - Rewrites `Block.stmts` for each reachable block that hoists a Return. +/// - Allocates fresh statements and expressions through `assigner`. +pub(super) fn hoist_returns_to_statement_boundary( + package: &mut Package, + assigner: &mut Assigner, + package_id: PackageId, + block_id: qsc_fir::fir::BlockId, +) -> bool { + let bound = fixpoint_bound(package); + let mut changed_any = false; + for _ in 0..bound { + let blocks = collect_reachable_blocks(package, block_id); + let mut changed_this_iter = false; + for b in blocks { + if hoist_block_once(package, assigner, package_id, b) { + changed_this_iter = true; + } + } + if !changed_this_iter { + return changed_any; + } + changed_any = true; + } + panic!("hoist_returns_to_statement_boundary exceeded fixpoint bound"); +} + +/// Collects every block transitively reachable from `root` without crossing +/// a closure boundary. The root itself is always included first. 
+fn collect_reachable_blocks( + package: &Package, + root: qsc_fir::fir::BlockId, +) -> Vec { + let mut out = Vec::new(); + let mut seen = rustc_hash::FxHashSet::default(); + visit_block_for_collect(package, root, &mut out, &mut seen); + out +} + +fn visit_block_for_collect( + package: &Package, + block_id: qsc_fir::fir::BlockId, + out: &mut Vec, + seen: &mut rustc_hash::FxHashSet, +) { + if !seen.insert(block_id) { + return; + } + out.push(block_id); + let stmts = package.get_block(block_id).stmts.clone(); + for stmt_id in stmts { + let stmt_kind = package.get_stmt(stmt_id).kind.clone(); + match stmt_kind { + StmtKind::Expr(e) | StmtKind::Semi(e) | StmtKind::Local(_, _, e) => { + visit_expr_for_collect(package, e, out, seen); + } + StmtKind::Item(_) => {} + } + } +} + +fn visit_expr_for_collect( + package: &Package, + expr_id: ExprId, + out: &mut Vec, + seen: &mut rustc_hash::FxHashSet, +) { + let kind = package.get_expr(expr_id).kind.clone(); + match kind { + ExprKind::Array(exprs) | ExprKind::ArrayLit(exprs) | ExprKind::Tuple(exprs) => { + for e in exprs { + visit_expr_for_collect(package, e, out, seen); + } + } + ExprKind::ArrayRepeat(a, b) + | ExprKind::Assign(a, b) + | ExprKind::AssignOp(_, a, b) + | ExprKind::BinOp(_, a, b) + | ExprKind::Call(a, b) + | ExprKind::Index(a, b) + | ExprKind::AssignField(a, _, b) + | ExprKind::UpdateField(a, _, b) => { + visit_expr_for_collect(package, a, out, seen); + visit_expr_for_collect(package, b, out, seen); + } + ExprKind::AssignIndex(a, b, c) | ExprKind::UpdateIndex(a, b, c) => { + visit_expr_for_collect(package, a, out, seen); + visit_expr_for_collect(package, b, out, seen); + visit_expr_for_collect(package, c, out, seen); + } + ExprKind::Block(b) => visit_block_for_collect(package, b, out, seen), + ExprKind::Closure(_, _) | ExprKind::Hole | ExprKind::Lit(_) | ExprKind::Var(_, _) => {} + ExprKind::Fail(e) | ExprKind::Field(e, _) | ExprKind::Return(e) | ExprKind::UnOp(_, e) => { + visit_expr_for_collect(package, e, out, 
seen); + } + ExprKind::If(cond, body, otherwise) => { + visit_expr_for_collect(package, cond, out, seen); + visit_expr_for_collect(package, body, out, seen); + if let Some(e) = otherwise { + visit_expr_for_collect(package, e, out, seen); + } + } + ExprKind::Range(start, step, end) => { + for e in [start, step, end].into_iter().flatten() { + visit_expr_for_collect(package, e, out, seen); + } + } + ExprKind::Struct(_, copy, fields) => { + if let Some(c) = copy { + visit_expr_for_collect(package, c, out, seen); + } + for fa in fields { + visit_expr_for_collect(package, fa.value, out, seen); + } + } + ExprKind::String(components) => { + for component in components { + if let StringComponent::Expr(e) = component { + visit_expr_for_collect(package, e, out, seen); + } + } + } + ExprKind::While(cond, block) => { + visit_expr_for_collect(package, cond, out, seen); + visit_block_for_collect(package, block, out, seen); + } + } +} + +/// Runs one hoist pass over a single block's direct statement list. +/// +/// Does not descend into nested blocks — those are visited independently by +/// the fixpoint driver. +fn hoist_block_once( + package: &mut Package, + assigner: &mut Assigner, + package_id: PackageId, + block_id: qsc_fir::fir::BlockId, +) -> bool { + let stmts = package.get_block(block_id).stmts.clone(); + let mut new_stmts: Vec = Vec::with_capacity(stmts.len()); + let mut changed = false; + for stmt_id in stmts { + if let Some(replacement) = hoist_stmt(package, assigner, package_id, stmt_id) { + new_stmts.extend(replacement); + changed = true; + } else { + new_stmts.push(stmt_id); + } + } + if changed { + let block = package.blocks.get_mut(block_id).expect("block not found"); + block.stmts = new_stmts; + } + changed +} + +/// Attempts to hoist any compound-position `Return` reachable from the +/// statement's surface expression. Returns `Some(replacement_stmts)` if the +/// statement must be replaced, where the last entry is the bare `return v;`. 
+fn hoist_stmt( + package: &mut Package, + assigner: &mut Assigner, + package_id: PackageId, + stmt_id: StmtId, +) -> Option> { + let (surface, is_bare_return_form) = match &package.get_stmt(stmt_id).kind { + StmtKind::Expr(e) | StmtKind::Semi(e) => { + let is_return = matches!(package.get_expr(*e).kind, ExprKind::Return(_)); + (*e, is_return) + } + StmtKind::Local(_, _, e) => (*e, false), + StmtKind::Item(_) => return None, + }; + + // When the statement is already `Semi(Return(v))` / `Expr(Return(v))`, + // the Return is at the statement boundary. Recurse into `inner` rather + // than `surface`: any hoistable Return inside `inner` fires before the + // outer Return evaluates, so its emitted statements (which already end + // in a bare `return ...;`) supersede the outer return entirely. + // + // If `inner` is a statement-carrying construct (`Block`/`If`/`While`) + // whose internal Returns sit at statement boundaries, `hoist_in_expr` + // returns `None` even though `inner` still contains Returns. The + // strategy pass cannot consume Returns sitting under a Return wrapper, + // so pin `inner` to a fresh `let __ret_hoist = inner;` binding and + // return the bound value. The strategy pass then rewrites the Local's + // initializer through its `LocalInit` handling, and the trailing + // `Semi(Return(Var))` is canonical. + // + // If `inner` has no Returns at all, the statement is already canonical + // — returning `Some` with a fresh Semi(Return(inner)) wrapping the same + // expression would let the fixpoint re-replace the statement forever. 
+ if is_bare_return_form { + let ExprKind::Return(inner) = package.get_expr(surface).kind else { + unreachable!() + }; + if let Some(stmts) = hoist_in_expr(package, assigner, package_id, inner) { + return Some(stmts); + } + if !contains_return_in_expr(package, inner) { + return None; + } + return Some(bind_inner_and_return(package, assigner, surface, inner)); + } + + hoist_in_expr(package, assigner, package_id, surface) +} + +/// Hoist any compound-position `Return` out of `expr_id`. +/// +/// # Before +/// ```text +/// f(a, return v, c) +/// ``` +/// # After +/// ```text +/// [let _ = a; return v;] // caller splices into enclosing block.stmts +/// ``` +/// # Requires +/// - `expr_id` is valid in `package`. +/// +/// # Ensures +/// - Returns `Some(stmts)` ending in `Semi(Return(..))` when a Return was lifted. +/// - Returns `None` when the subtree is return-free or the only Returns sit +/// behind a statement-carrying construct (`Block`, `If`, `While`) which the +/// downstream strategy pass handles. +/// - Preserves left-to-right evaluation order of earlier operands via +/// discard-`let` bindings; operands after the hoist point are dropped +/// because their results are dead. +/// - Short-circuit `and`/`or` RHS Returns are guarded with an `if`; LHS +/// Returns are unconditional. +/// +/// # Mutations +/// - Allocates new statements and expressions through `assigner`. +/// - Does not rewrite `expr_id`'s own node in place. +#[allow(clippy::match_same_arms)] // Statement-carrying vs leaf arms kept distinct for clarity. +fn hoist_in_expr( + package: &mut Package, + assigner: &mut Assigner, + package_id: PackageId, + expr_id: ExprId, +) -> Option> { + if !contains_return_in_expr(package, expr_id) { + return None; + } + let kind = package.get_expr(expr_id).kind.clone(); + match kind { + ExprKind::Return(inner) => { + // Degenerate `return (return x)`: inner Return fires first. 
+ if let Some(inner_stmts) = hoist_in_expr(package, assigner, package_id, inner) { + return Some(inner_stmts); + } + // Re-use the existing Return expression as a Semi statement. + let stmt = alloc_semi_stmt(package, assigner, expr_id, Span::default()); + Some(vec![stmt]) + } + + // Statement-carrying Block: leave to the strategy pass. + ExprKind::Block(_) => None, + + // If: the strategy pass handles Return in branches, but we must + // hoist any Return sitting in the *condition* slot because a + // condition-Return fires before either branch evaluates. Rewrite + // the whole If in place to a `Block` expression whose statements + // run the hoist and whose trailing expression supplies a default of + // the original type so the enclosing block's tail invariant is + // preserved. + ExprKind::If(cond, _, _) => hoist_in_cond(package, assigner, package_id, expr_id, cond), + // While: lift condition returns directly to statement boundary. + // Rewriting While-in-place to `Block` can hide callable-level + // early-exit semantics when the While is in statement position. + ExprKind::While(cond, _) => hoist_in_expr(package, assigner, package_id, cond), + + // Leaves: no sub-expression can hold a Return. + ExprKind::Closure(_, _) | ExprKind::Hole | ExprKind::Lit(_) | ExprKind::Var(_, _) => None, + + // Short-circuit logical operators: rewrite `a and/or b` in place to + // an equivalent `if` when the RHS (short-circuited operand) holds + // the Return, so the Return ends up in a branch of an If that the + // strategy pass consumes while the BinOp's `Bool` type is preserved. + ExprKind::BinOp(BinOp::AndL, a, b) => { + hoist_short_circuit(package, assigner, package_id, expr_id, a, b, true) + } + ExprKind::BinOp(BinOp::OrL, a, b) => { + hoist_short_circuit(package, assigner, package_id, expr_id, a, b, false) + } + + // Two-operand compounds evaluated left-to-right. 
+ ExprKind::BinOp(_, a, b) + | ExprKind::Call(a, b) + | ExprKind::Index(a, b) + | ExprKind::ArrayRepeat(a, b) + | ExprKind::Assign(a, b) + | ExprKind::AssignOp(_, a, b) + | ExprKind::AssignField(a, _, b) + | ExprKind::UpdateField(a, _, b) => hoist_n_ary(package, assigner, package_id, &[a, b]), + + // Three-operand compounds evaluated left-to-right. + ExprKind::AssignIndex(a, b, c) | ExprKind::UpdateIndex(a, b, c) => { + hoist_n_ary(package, assigner, package_id, &[a, b, c]) + } + + // N-ary compounds. + ExprKind::Array(exprs) | ExprKind::ArrayLit(exprs) | ExprKind::Tuple(exprs) => { + hoist_n_ary(package, assigner, package_id, &exprs) + } + + // Single-operand compounds — the operand's result is dead after a + // Return fires, so forward its hoist result directly. + ExprKind::UnOp(_, e) | ExprKind::Field(e, _) | ExprKind::Fail(e) => { + hoist_in_expr(package, assigner, package_id, e) + } + + // Optional operands in left-to-right order. + ExprKind::Range(start, step, end) => { + let operands: Vec = [start, step, end].into_iter().flatten().collect(); + hoist_n_ary(package, assigner, package_id, &operands) + } + + // `copy` (if present) evaluates before field values, in source order. + ExprKind::Struct(_, copy, fields) => { + let mut operands: Vec = Vec::with_capacity(fields.len() + 1); + if let Some(c) = copy { + operands.push(c); + } + for fa in &fields { + operands.push(fa.value); + } + hoist_n_ary(package, assigner, package_id, &operands) + } + + // Interpolated string components in source order. + ExprKind::String(components) => { + let operands: Vec = components + .into_iter() + .filter_map(|c| match c { + StringComponent::Expr(e) => Some(e), + StringComponent::Lit(_) => None, + }) + .collect(); + hoist_n_ary(package, assigner, package_id, &operands) + } + } +} + +/// Hoists a compound with operands evaluated strictly left-to-right. +/// +/// Finds the first operand whose subtree contains a hoistable `Return`. 
+/// Every earlier operand is bound to a discard-pattern `let` so its +/// side-effects execute in original source order; operands after the hoist +/// point are dead and dropped. +fn hoist_n_ary( + package: &mut Package, + assigner: &mut Assigner, + package_id: PackageId, + operands: &[ExprId], +) -> Option> { + for (i, &op) in operands.iter().enumerate() { + if let Some(op_stmts) = hoist_in_expr(package, assigner, package_id, op) { + let mut out: Vec = Vec::with_capacity(i + op_stmts.len()); + for &pre in &operands[..i] { + out.push(create_discard_let_stmt(package, assigner, pre)); + } + out.extend(op_stmts); + return Some(out); + } + } + None +} + +/// Handles `and`/`or` short-circuit `BinOp`s. +/// +/// * LHS Return is unconditional — lifted with no guard. +/// * RHS Return short-circuits: `and` fires only when LHS is `true`, +/// `or` fires only when LHS is `false`. We preserve the `BinOp`'s `Bool` +/// type and semantics by rewriting in place: +/// +/// ```text +/// a and b → if a { b } else { false } +/// a or b → if a { true } else { b } +/// ``` +/// +/// The Return now sits in a branch of an `If`, which the strategy pass +/// consumes, so the hoist itself does not need to emit statements. +fn hoist_short_circuit( + package: &mut Package, + assigner: &mut Assigner, + package_id: PackageId, + expr_id: ExprId, + a: ExprId, + b: ExprId, + is_and: bool, +) -> Option> { + // LHS always evaluates — an LHS Return is unconditional. + if let Some(stmts_a) = hoist_in_expr(package, assigner, package_id, a) { + return Some(stmts_a); + } + // LHS is clean; any hoistable Return must sit in the RHS. 
+    if !contains_return_in_expr(package, b) {
+        return None;
+    }
+    let lit_expr = {
+        let value = !is_and;
+        alloc_bool_lit(package, assigner, value, Span::default())
+    };
+    let (then_id, else_id) = if is_and { (b, lit_expr) } else { (lit_expr, b) };
+    let expr = package.exprs.get_mut(expr_id).expect("expr not found");
+    expr.kind = ExprKind::If(a, then_id, Some(else_id));
+    None
+}
+
+/// Handler for `If` condition returns. If the condition expression holds a
+/// `Return`, rewrites the surrounding expression in place to a `Block`
+/// expression whose statements execute the hoisted return and whose
+/// trailing expression provides a default value of the original expression's
+/// type so the enclosing block's tail invariant is preserved.
+///
+/// The branches / loop body are deliberately dropped: if the condition
+/// `Return` fires, control transfers out of the callable before any of
+/// them ever evaluates.
+fn hoist_in_cond(
+    package: &mut Package,
+    assigner: &mut Assigner,
+    package_id: PackageId,
+    expr_id: ExprId,
+    cond: ExprId,
+) -> Option<Vec<StmtId>> {
+    let stmts = hoist_in_expr(package, assigner, package_id, cond)?;
+    let orig_ty = package.get_expr(expr_id).ty.clone();
+    let mut block_stmts = stmts;
+    if orig_ty != Ty::UNIT {
+        let default = super::create_default_value(
+            package,
+            assigner,
+            package_id,
+            &orig_ty,
+            &rustc_hash::FxHashMap::default(),
+        )
+        .unwrap_or_else(|| {
+            panic!("return_unify: unsupported return type in hoisted condition: {orig_ty:?}")
+        });
+        block_stmts.push(alloc_expr_stmt(package, assigner, default, Span::default()));
+    }
+    let block_id = {
+        let ty: &Ty = &orig_ty;
+        alloc_block(package, assigner, block_stmts, ty.clone(), Span::default())
+    };
+    let expr = package.exprs.get_mut(expr_id).expect("expr not found");
+    expr.kind = ExprKind::Block(block_id);
+    // `expr.ty` already matches `orig_ty`; leave it as-is.
+ None +} + +/// Creates `let _ = expr_id;` — a discard-pattern `Local` whose sole +/// purpose is to preserve the operand's evaluation-order side-effects when +/// a later operand hoists a `Return` that discards the overall compound. +fn create_discard_let_stmt( + package: &mut Package, + assigner: &mut Assigner, + expr_id: ExprId, +) -> StmtId { + let ty = package.get_expr(expr_id).ty.clone(); + let pat_id: PatId = assigner.next_pat(); + package.pats.insert( + pat_id, + Pat { + id: pat_id, + span: Span::default(), + ty, + kind: PatKind::Discard, + }, + ); + let stmt_id = assigner.next_stmt(); + package.stmts.insert( + stmt_id, + Stmt { + id: stmt_id, + span: Span::default(), + kind: StmtKind::Local(Mutability::Immutable, pat_id, expr_id), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + stmt_id +} + +/// Pins a statement-carrying `inner` (Block/If/While with internal Returns) +/// to a fresh immutable `let __ret_hoist = inner;` binding and rewrites +/// `return_expr` to `Return(Var(__ret_hoist))`, yielding a two-statement +/// replacement for the original `Semi(Return(inner))`. +/// +/// # Why +/// The strategy pass cannot rewrite Returns that sit under a `Return` +/// wrapper (its classifier peeks at the top-level stmt expression kind and +/// stops). Binding `inner` to a Local instead exposes those Returns through +/// the `LocalInit` path, which the strategy pass does know how to rewrite. +/// +/// # Mutations +/// - Allocates a fresh `LocalVarId`, `PatId`, `StmtId`, and a `Var` `ExprId`. +/// - Mutates `return_expr`'s kind in place from `Return(inner)` to +/// `Return(Var(var_id))`. +/// +/// # Returns +/// Two statements, in order: the new `Local(__ret_hoist := inner)` and +/// a fresh `Semi(Return(Var))` reusing `return_expr`. 
+fn bind_inner_and_return(
+    package: &mut Package,
+    assigner: &mut Assigner,
+    return_expr: ExprId,
+    inner: ExprId,
+) -> Vec<StmtId> {
+    let inner_ty = package.get_expr(inner).ty.clone();
+    let local_var_id = assigner.next_local();
+    let pat_id = assigner.next_pat();
+    package.pats.insert(
+        pat_id,
+        Pat {
+            id: pat_id,
+            span: Span::default(),
+            ty: inner_ty.clone(),
+            kind: PatKind::Bind(Ident {
+                id: local_var_id,
+                span: Span::default(),
+                name: Rc::from("__ret_hoist"),
+            }),
+        },
+    );
+    let local_stmt_id = assigner.next_stmt();
+    package.stmts.insert(
+        local_stmt_id,
+        Stmt {
+            id: local_stmt_id,
+            span: Span::default(),
+            kind: StmtKind::Local(Mutability::Immutable, pat_id, inner),
+            exec_graph_range: EMPTY_EXEC_RANGE,
+        },
+    );
+
+    let var_expr_id = assigner.next_expr();
+    package.exprs.insert(
+        var_expr_id,
+        Expr {
+            id: var_expr_id,
+            span: Span::default(),
+            ty: inner_ty,
+            kind: ExprKind::Var(Res::Local(local_var_id), Vec::new()),
+            exec_graph_range: EMPTY_EXEC_RANGE,
+        },
+    );
+
+    // Rewrite the existing Return expression in place so it now wraps the
+    // Var, then wrap it in a fresh Semi statement.
+    let ret = package
+        .exprs
+        .get_mut(return_expr)
+        .expect("return expr not found");
+    ret.kind = ExprKind::Return(var_expr_id);
+    let return_stmt_id = alloc_semi_stmt(package, assigner, return_expr, Span::default());
+
+    vec![local_stmt_id, return_stmt_id]
+}
diff --git a/source/compiler/qsc_fir_transforms/src/return_unify/normalize/shape_tests.rs b/source/compiler/qsc_fir_transforms/src/return_unify/normalize/shape_tests.rs
new file mode 100644
index 0000000000..b3b833a0a4
--- /dev/null
+++ b/source/compiler/qsc_fir_transforms/src/return_unify/normalize/shape_tests.rs
@@ -0,0 +1,244 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+ +use crate::{ + PipelineStage, + return_unify::{tests::assert_no_reachable_returns, unify_returns}, + test_utils::compile_and_run_pipeline_to, +}; +use expect_test::{Expect, expect}; +use indoc::indoc; +use qsc_fir::assigner::Assigner; + +/// Compiles Q# source through `Mono`, captures a pretty-printed snapshot of +/// the package, runs `unify_returns` directly, captures a second snapshot, +/// and asserts the concatenated `BEFORE` / `AFTER` string matches `expect`. +/// +/// Shape-sensitive alternative to [`check_no_returns`]. Prefer behavior-only +/// assertions for the majority of tests; reserve this for cases where the +/// rewriting shape is itself under test. +fn check_before_after(source: &str, expect: &Expect) { + let (mut store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Mono); + let before = crate::pretty::write_package_qsharp(&store, pkg_id); + let mut assigner = Assigner::from_package(store.get(pkg_id)); + let errors = unify_returns(&mut store, pkg_id, &mut assigner); + assert!( + errors.is_empty(), + "return_unify shape test produced errors: {errors:?}" + ); + assert_no_reachable_returns(&store, pkg_id); + let after = crate::pretty::write_package_qsharp(&store, pkg_id); + let combined = format!("BEFORE:\n{before}\nAFTER:\n{after}"); + expect.assert_eq(&combined); +} + +#[test] +fn hoist_return_in_call_argument_shape_snapshot() { + // Flagship shape test — the same Q# shape as + // `hoist_return_in_call_argument`, but asserting the BEFORE/AFTER FIR + // pretty-print to lock the hoist shape. + check_before_after( + indoc! 
{r#" + namespace Test { + function Add(a : Int, b : Int) : Int { a + b } + @EntryPoint() + function Main() : Int { + let x = Add((return 1), 2); + x + } + } + "#}, + &expect![[r#" + BEFORE: + // namespace Test + function Add(a : Int, b : Int) : Int { + body { + a + b + } + } + function Main() : Int { + body { + let x : Int = Add(return 1, 2); + x + } + } + // entry + Main() + + AFTER: + // namespace Test + function Add(a : Int, b : Int) : Int { + body { + a + b + } + } + function Main() : Int { + body { + let _ : ((Int, Int) -> Int) = Add; + 1 + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn while_condition_return_shape_snapshot() { + check_before_after( + indoc! {r#" + namespace Test { + @EntryPoint() + function Main() : Int { + while if true { + if true { + return 31; + } else { + false + } + } else { + false + } { + let _ = 0; + } + 0 + } + } + "#}, + &expect![[r#" + BEFORE: + // namespace Test + function Main() : Int { + body { + while if true { + if true { + return 31; + } else { + false + } + + } else { + false + } + { + let _ : Int = 0; + } + + 0 + } + } + // entry + Main() + + AFTER: + // namespace Test + function Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + while not __has_returned and if true { + if true { + { + __ret_val = 31; + __has_returned = true; + }; + } else { + false + } + + } else { + false + } + { + let _ : Int = 0; + } + + let __trailing_result : Int = 0; + if __has_returned __ret_val else __trailing_result + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn while_local_initializer_return_shape_snapshot() { + check_before_after( + indoc! 
{r#" + namespace Test { + function Add(a : Int, b : Int) : Int { a + b } + + @EntryPoint() + function Main() : Int { + mutable i = 0; + while i < 3 { + let _ = if i == 1 { + Add((return 42), i) + }; + i += 1; + } + i + 5 + } + } + "#}, + &expect![[r#" + BEFORE: + // namespace Test + function Add(a : Int, b : Int) : Int { + body { + a + b + } + } + function Main() : Int { + body { + mutable i : Int = 0; + while i < 3 { + let _ : Unit = if i == 1 { + Add(return 42, i) + }; + i += 1; + } + + i + 5 + } + } + // entry + Main() + + AFTER: + // namespace Test + function Add(a : Int, b : Int) : Int { + body { + a + b + } + } + function Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + mutable i : Int = 0; + while not __has_returned and i < 3 { + let _ : Unit = if i == 1 { + let _ : ((Int, Int) -> Int) = Add; + { + __ret_val = 42; + __has_returned = true; + }; + }; + if not __has_returned { + i += 1; + }; + } + + let __trailing_result : Int = i + 5; + if __has_returned __ret_val else __trailing_result + } + } + // entry + Main() + "#]], + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/return_unify/normalize/tests.rs b/source/compiler/qsc_fir_transforms/src/return_unify/normalize/tests.rs new file mode 100644 index 0000000000..07b32910f6 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/return_unify/normalize/tests.rs @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +pub(super) use crate::return_unify::tests::{ + check_no_returns_q, check_structure, compile_return_unified, +}; +pub(super) use expect_test::{Expect, expect}; +pub(super) use indoc::indoc; + +use qsc_data_structures::language_features::LanguageFeatures; +use qsc_parse::namespaces; + +mod fixpoint; +mod flag_strategy; +mod hoist_expression; +mod nested_constructs; +mod regression_and_depth; +mod three_level; +mod three_level_mixed; + +// Each of the following tests exercises the `normalize::hoist_returns_to_statement_boundary` +// pre-pass by placing a `Return` inside a compound expression position. The +// invariant `check_no_returns` asserts that the combined hoist + transform +// produces PostReturnUnify-clean FIR (no `ExprKind::Return` survives). + +fn rendered_qsharp_parse_diagnostics(rendered: &str) -> Vec { + let rendered_without_entry = if let Some((before_entry, _)) = rendered.split_once("// entry\n") + { + before_entry.trim_end().to_string() + } else { + rendered.to_string() + }; + + let (_namespaces, errors) = namespaces( + &rendered_without_entry, + Some("roundtrip.qs"), + LanguageFeatures::default(), + ); + errors + .into_iter() + .map(|error| format!("{error:?}")) + .collect() +} + +pub(super) fn check_no_returns_q_roundtrip(source: &str, expect: &Expect) { + check_no_returns_q(source, expect); + + let (store, pkg_id) = compile_return_unified(source); + let rendered = crate::pretty::write_package_qsharp_parseable(&store, pkg_id); + let diagnostics = rendered_qsharp_parse_diagnostics(&rendered); + + assert!( + diagnostics.is_empty(), + "generated Q# should parse without diagnostics:\n{}\n\nrendered:\n{rendered}", + diagnostics.join("\n") + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/return_unify/normalize/tests/fixpoint.rs b/source/compiler/qsc_fir_transforms/src/return_unify/normalize/tests/fixpoint.rs new file mode 100644 index 0000000000..c5452f0de9 --- /dev/null +++ 
b/source/compiler/qsc_fir_transforms/src/return_unify/normalize/tests/fixpoint.rs @@ -0,0 +1,349 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Fixpoint termination boundary tests. + +use super::*; + +// The following tests exercise the `hoist_stmt` boundary case where the +// surface statement is already `Semi(Return(inner))` / `Expr(Return(inner))` +// and `inner` is a statement-carrying construct (`Block`, `If`, `While`) +// whose body holds a statement-level `Return`. A naive fixpoint that re- +// issues a fresh `Semi(Return(inner))` every iteration would loop forever; +// the hoist must either lift a return out of `inner` or leave the statement +// untouched so fixpoint terminates. + +#[test] +fn hoist_outer_return_wraps_if_with_return_in_then_branch() { + // `return if c { return X; } else { Y }` — the outer return wraps an + // `If` whose then-branch is a statement-level return. The strategy pass + // handles the inner return; the outer statement must stay fixed so the + // hoist fixpoint terminates. + check_no_returns_q( + indoc! {r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + return if M(q) == One { + return 1; + } else { + 2 + }; + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + let q : Qubit = __quantum__rt__qubit_allocate(); + { + if M(q) == One { + { + let + @generated_ident_36 : Int = 1; + __quantum__rt__qubit_release(q); + @generated_ident_36 + } + + } else { + let + @generated_ident_35 : Int = { + 2 + }; + __quantum__rt__qubit_release(q); + @generated_ident_35 + } + + } + + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn hoist_outer_return_wraps_if_with_returns_in_both_branches() { + // Both branches terminate with statement-level returns inside an outer + // `return`. 
Exercises the cross-product of the boundary case. + check_no_returns_q( + indoc! {r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + return if M(q) == One { + return 1; + } else { + return 2; + }; + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + let q : Qubit = __quantum__rt__qubit_allocate(); + { + if M(q) == One { + { + let + @generated_ident_37 : Int = 1; + __quantum__rt__qubit_release(q); + @generated_ident_37 + } + + } else { + { + let + @generated_ident_49 : Int = 2; + __quantum__rt__qubit_release(q); + @generated_ident_49 + } + + } + + } + + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn hoist_outer_return_wraps_block_with_stmt_level_return() { + // `return { side_effect(); return X; trailing }` — outer return wraps a + // `Block` whose statement list contains a `Semi(Return)`. The strategy + // pass handles the inner return; the outer statement must stay fixed. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + return { + if M(q) == One { + return 1; + } + 2 + }; + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + let q : Qubit = __quantum__rt__qubit_allocate(); + { + let + @generated_ident_36 : Int = { + if M(q) == One { + { + let + @generated_ident_37 : Int = 1; + __quantum__rt__qubit_release(q); + @generated_ident_37 + } + + } else { + 2 + } + + }; + __quantum__rt__qubit_release(q); + @generated_ident_36 + } + + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn hoist_outer_return_wraps_if_whose_condition_has_return() { + // `return if (return X) { 1 } else { 2 }` — the outer return wraps an + // `If` whose *condition* holds an unconditional return. The inner hoist + // rewrites the `If` in place to a `Block` (via `hoist_in_cond`); the + // outer statement must then terminate on the next fixpoint iteration + // instead of re-emitting a fresh Semi(Return(Block)) forever. + check_no_returns_q( + indoc! {r#" + namespace Test { + function Main() : Int { + return if (return 7) { + 1 + } else { + 2 + }; + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + { + 7 + } + + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn hoist_outer_return_wraps_while_with_return_body() { + // `return while c { ...; return (); }` in a Unit-returning callable. + // The outer return wraps a `While` whose body contains a statement-level + // return. Exercises the While arm of the boundary case. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + @EntryPoint() + operation Main() : Unit { + mutable i = 0; + return while i < 3 { + if i == 1 { + return (); + } + i += 1; + }; + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Unit { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Unit = (); + mutable i : Int = 0; + let __ret_hoist : Unit = while not __has_returned and i < 3 { + if i == 1 { + { + __ret_val = (); + __has_returned = true; + }; + } + + if not __has_returned { + i += 1; + }; + }; + if not __has_returned { + { + __ret_val = __ret_hoist; + __has_returned = true; + }; + }; + if __has_returned __ret_val else () + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn hoist_outer_return_wraps_nested_ifs_with_deep_stmt_return() { + // Nested `if`s inside a `return`, with a statement-level return at the + // deepest level. Verifies the fixpoint handles multi-level statement- + // carrying constructs under a bare outer return without looping. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + return if M(q) == One { + if M(q) == Zero { + return 1; + } + 2 + } else { + 3 + }; + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + let q : Qubit = __quantum__rt__qubit_allocate(); + { + if M(q) == One { + if M(q) == Zero { + { + let + @generated_ident_47 : Int = 1; + __quantum__rt__qubit_release(q); + @generated_ident_47 + } + + } else { + 2 + } + + } else { + let + @generated_ident_46 : Int = { + 3 + }; + __quantum__rt__qubit_release(q); + @generated_ident_46 + } + + } + + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/return_unify/normalize/tests/flag_strategy.rs b/source/compiler/qsc_fir_transforms/src/return_unify/normalize/tests/flag_strategy.rs new file mode 100644 index 0000000000..95fef8e07d --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/return_unify/normalize/tests/flag_strategy.rs @@ -0,0 +1,460 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Flag-strategy tests: specializations, while-body returns, local-init retypes, +//! and flag-fallback edge cases. + +use super::*; + +#[test] +fn adjoint_spec_hoist_in_call_arg() { + // Return in a Call argument inside an explicit `adjoint` specialization. + check_no_returns_q( + indoc! {r#" + namespace Test { + operation Inner(x : Int, q : Qubit) : Unit is Adj { + body ... { X(q); } + adjoint self; + } + operation Outer(n : Int, q : Qubit) : Unit is Adj { + body ... { Inner(n, q); } + adjoint ... 
{ + Inner((return ()), q); + } + } + @EntryPoint() + operation Main() : Unit { + use q = Qubit(); + Adjoint Outer(1, q); + } + } + "#}, + &expect![[r#" + // namespace Test + operation Inner(x : Int, q : Qubit) : Unit is Adj { + body { + X(q); + } + adjoint { + X(q); + } + } + operation Outer(n : Int, q : Qubit) : Unit is Adj { + body { + Inner(n, q); + } + adjoint { + let _ : ((Int, Qubit) => Unit is Adj) = Inner; + () + } + } + operation Main() : Unit { + body { + let q : Qubit = __quantum__rt__qubit_allocate(); + Adjoint Outer(1, q); + __quantum__rt__qubit_release(q); + } + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn controlled_spec_hoist_in_call_arg() { + // Return in a Call argument inside an explicit `controlled` specialization. + // Disposition: documented contract. Snapshot keeps current callable + // signature text, while round-trip compilation confirms validity. + check_no_returns_q_roundtrip( + indoc! {r#" + namespace Test { + operation Outer(n : Int, q : Qubit) : Unit is Ctl { + body ... { H(q); } + controlled (ctls, ...) 
{ + Controlled H(ctls, (return ())); + } + } + @EntryPoint() + operation Main() : Unit { + use (c, q) = (Qubit(), Qubit()); + Controlled Outer([c], (1, q)); + } + } + "#}, + &expect![[r#" + // namespace Test + operation Outer(n : Int, q : Qubit) : Unit is Ctl { + body { + H(q); + } + controlled { + let _ : ((Qubit[], Qubit) => Unit is Adj + Ctl) = Controlled H; + let _ : Qubit[] = _local3; + () + } + } + operation Main() : Unit { + body { + let + @generated_ident_53 : Qubit = __quantum__rt__qubit_allocate(); + let + @generated_ident_55 : Qubit = __quantum__rt__qubit_allocate(); + let (c : Qubit, q : Qubit) = ( + @generated_ident_53, + @generated_ident_55 + ); + Controlled Outer([c], (1, q)); + __quantum__rt__qubit_release( + @generated_ident_55 + ); + __quantum__rt__qubit_release( + @generated_ident_53 + ); + } + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn controlled_adjoint_spec_hoist_in_call_arg() { + // Return in a Call argument inside an explicit `controlled adjoint` + // specialization. + check_no_returns_q_roundtrip( + indoc! {r#" + namespace Test { + operation Outer(n : Int, q : Qubit) : Unit is Adj + Ctl { + body ... { H(q); } + adjoint ... { H(q); } + controlled (ctls, ...) { Controlled H(ctls, q); } + controlled adjoint (ctls, ...) 
{ + Controlled H(ctls, (return ())); + } + } + @EntryPoint() + operation Main() : Unit { + use (c, q) = (Qubit(), Qubit()); + Controlled Adjoint Outer([c], (1, q)); + } + } + "#}, + &expect![[r#" + // namespace Test + operation Outer(n : Int, q : Qubit) : Unit is Adj + Ctl { + body { + H(q); + } + adjoint { + H(q); + } + controlled { + Controlled H(_local3, q); + } + controlled adjoint { + let _ : ((Qubit[], Qubit) => Unit is Adj + Ctl) = Controlled H; + let _ : Qubit[] = _local4; + () + } + } + operation Main() : Unit { + body { + let + @generated_ident_71 : Qubit = __quantum__rt__qubit_allocate(); + let + @generated_ident_73 : Qubit = __quantum__rt__qubit_allocate(); + let (c : Qubit, q : Qubit) = ( + @generated_ident_71, + @generated_ident_73 + ); + Controlled Adjoint Outer([c], (1, q)); + __quantum__rt__qubit_release( + @generated_ident_73 + ); + __quantum__rt__qubit_release( + @generated_ident_71 + ); + } + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn while_body_with_call_arg_return() { + // While body containing a Call-argument Return. The outer transform + // routes this through the flag-based path because the Return sits + // inside a while body. + // Disposition: documented contract. Snapshot keeps historical identifier + // spellings, while round-trip compilation confirms generated Q# validity. + check_no_returns_q_roundtrip( + indoc! 
{r#" + namespace Test { + function Add(a : Int, b : Int) : Int { a + b } + function Main() : Int { + mutable i = 0; + while i < 3 { + let _ = Add((return 42), 2); + i += 1; + } + -1 + } + } + "#}, + &expect![[r#" + // namespace Test + function Add(a : Int, b : Int) : Int { + body { + a + b + } + } + function Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + mutable i : Int = 0; + while not __has_returned and i < 3 { + let _ : ((Int, Int) -> Int) = Add; + { + __ret_val = 42; + __has_returned = true; + }; + if not __has_returned { + i += 1; + }; + } + + let __trailing_result : Int = -1; + if __has_returned __ret_val else __trailing_result + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn local_init_retype_in_call_arg_fix() { + // `let x = if c { return 1 } else { 0 }; Identity(x);` — after hoist + + // if-else transform, the local `x` must hold an Int (the transformed + // initializer's new type), not the diverging type from the pre-transform + // Return. + check_no_returns_q_roundtrip( + indoc! {r#" + namespace Test { + function Identity(x : Int) : Int { x } + function Main() : Int { + let c = true; + let x = if c { return 1 } else { 0 }; + Identity(x) + } + } + "#}, + &expect![[r#" + // namespace Test + function Identity(x : Int) : Int { + body { + x + } + } + function Main() : Int { + body { + let c : Bool = true; + if c { + 1 + } else { + let x : Int = { + 0 + }; + Identity(x) + } + + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn nested_block_middle_of_block_fix() { + // `{ if c { return 1; } 2 }; let y = 3; y` — a nested Block expression + // containing an if-return-then-value sits in the middle of the outer + // block. Regression for middle-of-block nested-block rewrite must + // produce a Block whose trailing expression preserves the outer block's + // structural invariants. + check_no_returns_q_roundtrip( + indoc! 
{r#" + namespace Test { + function Main() : Int { + let c = true; + let _unused = { + if c { return 1; } + 2 + }; + let y = 3; + y + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + let c : Bool = true; + let _unused : Int = { + if c { + { + __ret_val = 1; + __has_returned = true; + }; + } + + 2 + }; + let y : Int = if not __has_returned { + 3 + } else { + 0 + }; + let __trailing_result : Int = y; + if __has_returned __ret_val else __trailing_result + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn flag_fallback_handles_arrow_return() { + // A callable-valued Return inside a while body forces the flag-based + // fallback to synthesize a default of arrow type. `create_default_value` + // handles this by synthesizing a nop callable item of matching + // signature and using `Var(Res::Item(..))` as the `__ret_val` seed; the + // nop is never actually invoked because `__has_returned` guards every + // read of `__ret_val`. + let source = indoc! 
{r#" + namespace Test { + function MakeAdder(n : Int) : (Int -> Int) { + mutable i = 0; + while i < 3 { + if i == n { + return (x -> x + 1); + } + i += 1; + } + x -> x + } + @EntryPoint() + function Main() : Int { + let f = MakeAdder(1); + f(10) + } + } + "#}; + let _ = compile_return_unified(source); + check_no_returns_q( + source, + &expect![[r#" + // namespace Test + function MakeAdder(n : Int) : (Int -> Int) { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : (Int -> Int) = __return_unify_nop_5; + mutable i : Int = 0; + while not __has_returned and i < 3 { + if i == n { + { + __ret_val = / * closure item = 3 captures = [] * / < lambda >; + __has_returned = true; + }; + } + + if not __has_returned { + i += 1; + }; + } + + let __trailing_result : (Int -> Int) = / * closure item = 4 captures = [] * / < lambda >; + if __has_returned __ret_val else __trailing_result + } + } + function Main() : Int { + body { + let f : (Int -> Int) = MakeAdder(1); + f(10) + } + } + function < lambda > (x : Int, ) : Int { + body { + x + 1 + } + } + function < lambda > (x : Int, ) : Int { + body { + x + } + } + function __return_unify_nop_5(_ : Int) : Int { + body { + 0 + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn flag_fallback_supports_post_return_range_local_initializer() { + let source = indoc! {r#" + namespace Test { + function Main() : Int { + mutable i = 0; + while i < 3 { + if i == 1 { + return i; + } + i += 1; + } + let r = 0..3; + 0 + } + } + "#}; + + let (store, pkg_id) = compile_return_unified(source); + let rendered = crate::pretty::write_package_qsharp(&store, pkg_id); + + assert!( + rendered.contains("let r : Range = if not __has_returned {"), + "post-return range local initializers should be guarded under the flag strategy", + ); + // After bind-then-check fix, the trailing expression is bound to __trailing_result + // before the flag check. 
+ assert!( + rendered.contains("let __trailing_result : Int =") + && rendered.contains("if __has_returned __ret_val else __trailing_result"), + "final trailing expression should use bind-then-check pattern", + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/return_unify/normalize/tests/hoist_expression.rs b/source/compiler/qsc_fir_transforms/src/return_unify/normalize/tests/hoist_expression.rs new file mode 100644 index 0000000000..fed761c321 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/return_unify/normalize/tests/hoist_expression.rs @@ -0,0 +1,785 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Hoist-return tests: returns inside compound expression positions. + +use super::*; + +use crate::walk_utils::for_each_expr_in_callable_impl; +use qsc_fir::fir::{ + BinOp, CallableImpl, ExprKind, ItemKind, Lit, LocalVarId, Package, PackageLookup, PatKind, Res, + StmtKind, UnOp, +}; + +fn find_main_decl(package: &Package) -> &qsc_fir::fir::CallableDecl { + package + .items + .values() + .find_map(|item| match &item.kind { + ItemKind::Callable(decl) if decl.name.name.as_ref() == "Main" => Some(decl), + _ => None, + }) + .expect("callable 'Main' not found") +} + +fn find_top_level_local_var_id( + package: &Package, + body_block_id: qsc_fir::fir::BlockId, + local_name: &str, +) -> LocalVarId { + let body_block = package.get_block(body_block_id); + body_block + .stmts + .iter() + .find_map(|&stmt_id| { + let stmt_kind = package.get_stmt(stmt_id).kind.clone(); + let StmtKind::Local(_, pat_id, _init_expr_id) = stmt_kind else { + return None; + }; + let pat = package.get_pat(pat_id); + let PatKind::Bind(ident) = &pat.kind else { + return None; + }; + (ident.name.as_ref() == local_name).then_some(ident.id) + }) + .unwrap_or_else(|| panic!("local '{local_name}' not found in Main body")) +} + +fn expr_reads_local( + package: &Package, + expr_id: qsc_fir::fir::ExprId, + expected_local: LocalVarId, +) -> bool { + let 
expr_kind = package.get_expr(expr_id).kind.clone(); + matches!(expr_kind, ExprKind::Var(Res::Local(local_id), _) if local_id == expected_local) +} + +fn is_not_flag_expr( + package: &Package, + expr_id: qsc_fir::fir::ExprId, + has_returned_var_id: LocalVarId, +) -> bool { + let expr_kind = package.get_expr(expr_id).kind.clone(); + let ExprKind::UnOp(UnOp::NotL, inner_expr_id) = expr_kind else { + return false; + }; + expr_reads_local(package, inner_expr_id, has_returned_var_id) +} + +fn assert_while_condition_return_flag_shape(source: &str, expected_ret_val: i64) { + let (store, pkg_id) = compile_return_unified(source); + let package = store.get(pkg_id); + let main_decl = find_main_decl(package); + + let CallableImpl::Spec(spec_impl) = &main_decl.implementation else { + panic!("Main must have a spec body") + }; + let body_block_id = spec_impl.body.block; + let body_block = package.get_block(body_block_id); + + let has_returned_var_id = find_top_level_local_var_id(package, body_block_id, "__has_returned"); + let ret_val_var_id = find_top_level_local_var_id(package, body_block_id, "__ret_val"); + + let while_cond_id = body_block + .stmts + .iter() + .find_map(|&stmt_id| { + let stmt_kind = package.get_stmt(stmt_id).kind.clone(); + let expr_id = match stmt_kind { + StmtKind::Expr(expr_id) | StmtKind::Semi(expr_id) => expr_id, + StmtKind::Local(_, _, _) | StmtKind::Item(_) => return None, + }; + let expr_kind = package.get_expr(expr_id).kind.clone(); + let ExprKind::While(cond_id, _body_id) = expr_kind else { + return None; + }; + Some(cond_id) + }) + .expect("expected Main body to contain rewritten while loop"); + + let cond_kind = package.get_expr(while_cond_id).kind.clone(); + let ExprKind::BinOp(BinOp::AndL, lhs_expr_id, _rhs_expr_id) = cond_kind else { + panic!("while condition should be conjoined with not __has_returned") + }; + assert!( + is_not_flag_expr(package, lhs_expr_id, has_returned_var_id), + "while condition LHS should be not __has_returned" + ); + + 
let trailing_stmt_id = *body_block + .stmts + .last() + .expect("Main body should have trailing expression"); + let trailing_stmt_kind = package.get_stmt(trailing_stmt_id).kind.clone(); + let StmtKind::Expr(trailing_expr_id) = trailing_stmt_kind else { + panic!("Main body should end with trailing Expr") + }; + let trailing_expr_kind = package.get_expr(trailing_expr_id).kind.clone(); + let ExprKind::If(flag_expr_id, then_expr_id, Some(else_expr_id)) = trailing_expr_kind else { + panic!("expected trailing merge expression if __has_returned ...") + }; + + assert!( + expr_reads_local(package, flag_expr_id, has_returned_var_id), + "trailing merge condition should read __has_returned" + ); + assert!( + expr_reads_local(package, then_expr_id, ret_val_var_id), + "trailing merge then-branch should read __ret_val" + ); + // After bind-then-check fix, the else branch reads __trailing_result rather than + // the literal directly. + assert!( + matches!( + package.get_expr(else_expr_id).kind, + ExprKind::Var(Res::Local(_), _) + ), + "trailing merge else-branch should read __trailing_result" + ); + + let mut saw_ret_assignment = false; + let mut saw_flag_assignment = false; + for_each_expr_in_callable_impl(package, &main_decl.implementation, &mut |_expr_id, expr| { + let expr_kind = expr.kind.clone(); + let ExprKind::Assign(lhs_expr_id, rhs_expr_id) = expr_kind else { + return; + }; + let lhs_kind = package.get_expr(lhs_expr_id).kind.clone(); + let ExprKind::Var(Res::Local(local_id), _) = lhs_kind else { + return; + }; + + if local_id == ret_val_var_id + && matches!(package.get_expr(rhs_expr_id).kind, ExprKind::Lit(Lit::Int(value)) if value == expected_ret_val) + { + saw_ret_assignment = true; + } + + if local_id == has_returned_var_id + && matches!( + package.get_expr(rhs_expr_id).kind, + ExprKind::Lit(Lit::Bool(true)) + ) + { + saw_flag_assignment = true; + } + }); + + assert!( + saw_ret_assignment, + "expected rewritten while-condition return path to assign __ret_val = 
{expected_ret_val}" + ); + assert!( + saw_flag_assignment, + "expected rewritten while-condition return path to set __has_returned = true" + ); +} + +#[test] +fn hoist_return_in_call_argument() { + // `Add((return 1), 2)` — Return lives in the first tuple slot of a Call. + // Disposition: documented contract. Snapshot keeps historical identifier + // spellings, while round-trip compilation confirms generated Q# validity. + check_no_returns_q_roundtrip( + indoc! {r#" + namespace Test { + function Add(a : Int, b : Int) : Int { a + b } + function Main() : Int { + let x = Add((return 1), 2); + x + } + } + "#}, + &expect![[r#" + // namespace Test + function Add(a : Int, b : Int) : Int { + body { + a + b + } + } + function Main() : Int { + body { + let _ : ((Int, Int) -> Int) = Add; + 1 + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn hoist_return_in_tuple_middle() { + // `(1, return 2, 3)` — Return in the middle of a tuple literal. + check_no_returns_q_roundtrip( + indoc! {r#" + namespace Test { + function Main() : Int { + let (a, _, _) = (1, (return 2), 3); + a + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + let _ : Int = 1; + 2 + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn hoist_return_in_array_first() { + // `[return 1, 2, 3]` — Return at the head of an array literal. + check_no_returns_q_roundtrip( + indoc! {r#" + namespace Test { + function Main() : Int { + let a = [(return 1), 2, 3]; + a[0] + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + 1 + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn hoist_return_in_array_repeat() { + // `[0, size = return 3]` — Return as the size argument of an + // array-repeat literal. + check_no_returns_q_roundtrip( + indoc! 
{r#" + namespace Test { + function Main() : Int { + let a = [0, size = (return 3)]; + a[0] + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + let _ : Int = 0; + 3 + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn hoist_return_in_binop_rhs_arithmetic() { + // `a + (return 1)` — Return as the RHS of an arithmetic BinOp. + check_no_returns_q_roundtrip( + indoc! {r#" + namespace Test { + function Main() : Int { + let a = 1; + let x = a + (return 1); + x + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + let a : Int = 1; + let _ : Int = a; + 1 + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn hoist_return_in_short_circuit_and_rhs() { + // `a and (return true)` — Return on the RHS of a short-circuit And. + check_no_returns_q_roundtrip( + indoc! {r#" + namespace Test { + function Main() : Bool { + true and (return true) + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Bool { + body { + if true true else false + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn hoist_return_in_short_circuit_or_rhs() { + // `a or (return true)` — Return on the RHS of a short-circuit Or. + check_no_returns_q_roundtrip( + indoc! {r#" + namespace Test { + function Main() : Bool { + false or (return true) + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Bool { + body { + if not false true else true + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn hoist_return_in_unop() { + // `-(return 1)` — Return as the operand of a UnOp. + check_no_returns_q( + indoc! {r#" + namespace Test { + function Main() : Int { + let x = -(return 1); + x + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + 1 + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn hoist_return_in_index_expr() { + // `arr[return 0]` — Return as the index of an Index expression. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + function Main() : Int { + let arr = [10, 20, 30]; + let i : Int = return 0; + arr[i] + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + let arr : Int[] = [10, 20, 30]; + 0 + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn hoist_return_in_update_index_value() { + // `arr w/ 0 <- (return 1)` — Return as the RHS of an UpdateIndex. + check_no_returns_q( + indoc! {r#" + namespace Test { + function Main() : Int[] { + let arr = [0, 0, 0]; + let a2 = arr w/ 0 <- (return []); + a2 + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int[] { + body { + let arr : Int[] = [0, 0, 0]; + let _ : Int[] = arr; + let _ : Int = 0; + [] + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn hoist_return_in_struct_field() { + // `new T { F = return v }` — Return as a struct-field initializer. + check_no_returns_q( + indoc! {r#" + namespace Test { + struct Pair { First : Int, Second : Int } + function Main() : Int { + let p = new Pair { First = (return 1), Second = 2 }; + p.First + } + } + "#}, + &expect![[r#" + // namespace Test + newtype Pair = (Int, Int); + function Main() : Int { + body { + 1 + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn hoist_return_in_range_endpoint() { + // `for i in 0..(return 5) { ... }` — Return in a range endpoint, inside + // a for-loop (loop_unification lowers the range into `__range_{start,step,end}` + // locals, so the hoist sees the Return in a local-initializer position). + check_no_returns_q( + indoc! {r#" + namespace Test { + function Main() : Int { + mutable sum = 0; + for i in 0..(return 5) { + sum += i; + } + sum + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + mutable sum : Int = 0; + { + let _ : Int = 0; + 5 + } + + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn hoist_return_in_fail_payload() { + // `fail (return "msg")` — Return as the payload of a fail expression. 
+ check_no_returns_q( + indoc! {r#" + namespace Test { + function Main() : String { + fail (return "done"); + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : String { + body { + $"done" + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn hoist_return_in_string_interp() { + // `$"foo {return x} bar"` — Return inside an interpolated string segment. + check_no_returns_q( + indoc! {r#" + namespace Test { + function Main() : String { + let s = $"foo {(return "early")} bar"; + s + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : String { + body { + $"early" + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn hoist_return_in_if_condition() { + // `if (return 7) { ... }` — Return in the condition slot of an If + // expression. Condition hoisting lifts that return to statement + // boundary, so the If collapses to a block that yields `7`. + check_no_returns_q_roundtrip( + indoc! {r#" + namespace Test { + function Main() : Int { + if (return 7) { + 1 + } else { + 2 + } + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + { + 7 + } + + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn hoist_return_in_while_condition() { + // `while (return 9) { ... }` — Return in the condition of a While. + // Condition hoisting lifts the return ahead of the loop, making the + // loop body unreachable. + check_no_returns_q_roundtrip( + indoc! {r#" + namespace Test { + function Main() : Int { + while (return 9) { + let _ = 0; + } + 0 + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + 9 + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn hoist_return_in_while_condition_nested_if_unconditional_path() { + // Complex condition shape with nested Ifs plus an unconditional + // return-bearing left operand of `and`. + // The post-loop fallback `0` must not be accepted. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + function Main() : Int { + while ((return 13) > 0) and (if true { + if true { + return 99; + } else { + false + } + } else { + false + }) { + let _ = 0; + } + 0 + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + 13 + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn hoist_return_in_while_condition_short_circuit_and_or_unconditional_path() { + // `while (((return 17) > 0) or (false and (return 23))) and true { ... }`. + // The left side unconditionally returns before any fallthrough value can + // be observed, even with nested short-circuit `and`/`or` shape. + // The post-loop fallback `0` must not be accepted. + check_no_returns_q( + indoc! {r#" + namespace Test { + function Main() : Int { + while (((return 17) > 0) or (false and (return 23))) and true { + let _ = 0; + } + 0 + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + 17 + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn while_condition_direct_nested_if_return_uses_flag_strategy() { + let source = indoc! {r#" + namespace Test { + function Main() : Int { + while if true { + if true { + return 31; + } else { + false + } + } else { + false + } { + let _ = 0; + } + 0 + } + } + "#}; + + assert_while_condition_return_flag_shape(source, 31); +} + +#[test] +fn while_condition_short_circuit_rhs_return_uses_flag_strategy() { + let source = indoc! {r#" + namespace Test { + function Main() : Int { + while true and (return 37) { + let _ = 0; + } + 0 + } + } + "#}; + + assert_while_condition_return_flag_shape(source, 37); +} + +#[test] +fn hoist_return_return_x() { + // `return (return 1)` — degenerate nested Return. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + function Main() : Int { + return (return 1); + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + 1 + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn hoist_return_chained() { + // `Add(Add((return 1), 0), 2)` — Return at a deeply nested compound + // position. Exercises the iterative fixed-point shape of the hoist. + // Disposition: documented contract. Snapshot keeps historical identifier + // spellings, while round-trip compilation confirms generated Q# validity. + check_no_returns_q_roundtrip( + indoc! {r#" + namespace Test { + function Add(a : Int, b : Int) : Int { a + b } + function Main() : Int { + let x = Add(Add((return 1), 0), 2); + x + } + } + "#}, + &expect![[r#" + // namespace Test + function Add(a : Int, b : Int) : Int { + body { + a + b + } + } + function Main() : Int { + body { + let _ : ((Int, Int) -> Int) = Add; + let _ : ((Int, Int) -> Int) = Add; + 1 + } + } + // entry + Main() + "#]], + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/return_unify/normalize/tests/nested_constructs.rs b/source/compiler/qsc_fir_transforms/src/return_unify/normalize/tests/nested_constructs.rs new file mode 100644 index 0000000000..46b8fd92fa --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/return_unify/normalize/tests/nested_constructs.rs @@ -0,0 +1,717 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Category A (nested if-without-else) and Category B (nested while/for/mixed) +//! normalization tests. + +use super::*; + +// Category A: nested if-without-else with a deep return + +#[test] +fn if_if_return_then_trailing() { + // Depth-2 if-without-else leaf return with a trailing continuation. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + if M(q) == One { + if M(q) == Zero { + return 1; + } + } + 2 + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + let q : Qubit = __quantum__rt__qubit_allocate(); + if M(q) == One { + if M(q) == Zero { + { + let + @generated_ident_41 : Int = 1; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_41; + __has_returned = true; + }; + }; + } + + } + + let + @generated_ident_53 : Int = if not __has_returned { + 2 + } else { + 0 + }; + if not __has_returned { + __quantum__rt__qubit_release(q); + }; + let __trailing_result : Int = + @generated_ident_53; + if __has_returned __ret_val else __trailing_result + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn if_if_return_no_trailing_unit() { + // Unit-typed callable version of the same shape. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Unit { + use q = Qubit(); + if M(q) == One { + if M(q) == Zero { + return (); + } + } + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Unit { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Unit = (); + let q : Qubit = __quantum__rt__qubit_allocate(); + let + @generated_ident_51 : Unit = if M(q) == One { + if M(q) == Zero { + { + let + @generated_ident_39 : Unit = (); + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_39; + __has_returned = true; + }; + }; + } + + }; + if not __has_returned { + __quantum__rt__qubit_release(q); + }; + let __trailing_result : Unit = + @generated_ident_51; + if __has_returned __ret_val else __trailing_result + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn if_if_return_sibling_stmt_before_if() { + // Statements precede the leaky if-if-return; their side effects must + // survive the flag rewrite. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + mutable acc = 0; + acc += 10; + if M(q) == One { + if M(q) == Zero { + return acc; + } + } + acc + 1 + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + let q : Qubit = __quantum__rt__qubit_allocate(); + mutable acc : Int = 0; + acc += 10; + if M(q) == One { + if M(q) == Zero { + { + let + @generated_ident_51 : Int = acc; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_51; + __has_returned = true; + }; + }; + } + + } + + let + @generated_ident_63 : Int = if not __has_returned { + acc + 1 + } else { + 0 + }; + if not __has_returned { + __quantum__rt__qubit_release(q); + }; + let __trailing_result : Int = + @generated_ident_63; + if __has_returned __ret_val else __trailing_result + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn if_if_return_inside_block_wrapper() { + // Block wrapper around the leaky if-if-return. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + { + if M(q) == One { + if M(q) == Zero { + return 1; + } + } + }; + 2 + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + let q : Qubit = __quantum__rt__qubit_allocate(); + { + if M(q) == One { + if M(q) == Zero { + { + let + @generated_ident_44 : Int = 1; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_44; + __has_returned = true; + }; + }; + } + + } + + }; + let + @generated_ident_56 : Int = if not __has_returned { + 2 + } else { + 0 + }; + if not __has_returned { + __quantum__rt__qubit_release(q); + }; + let __trailing_result : Int = + @generated_ident_56; + if __has_returned __ret_val else __trailing_result + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn if_elseif_if_return_deep() { + // if / elif / if with deepest return in the last arm. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + if M(q) == One { + 1 + } elif M(q) == Zero { + if M(q) == One { + return 2; + } + 3 + } else { + 4 + } + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + let q : Qubit = __quantum__rt__qubit_allocate(); + if not M(q) == One if M(q) == Zero { + if M(q) == One { + { + let + @generated_ident_55 : Int = 2; + __quantum__rt__qubit_release(q); + @generated_ident_55 + } + + } else { + 3 + } + + } else { + 4 + } else { + let + @generated_ident_67 : Int = { + 1 + }; + __quantum__rt__qubit_release(q); + @generated_ident_67 + } + + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +// Category B: nested while / for / mixed with a deep return + +#[test] +fn while_while_return_deep() { + // Depth-2 nested whiles with the return in the innermost body. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + mutable i = 0; + mutable j = 0; + use q = Qubit(); + while i < 2 { + while j < 2 { + if M(q) == One { + return 7; + } + j += 1; + } + i += 1; + } + 0 + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + mutable i : Int = 0; + mutable j : Int = 0; + let q : Qubit = __quantum__rt__qubit_allocate(); + while not __has_returned and i < 2 { + while not __has_returned and j < 2 { + if M(q) == One { + { + let + @generated_ident_60 : Int = 7; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_60; + __has_returned = true; + }; + }; + } + + if not __has_returned { + j += 1; + }; + } + + if not __has_returned { + i += 1; + }; + } + + let + @generated_ident_72 : Int = { + 0 + }; + if not __has_returned { + __quantum__rt__qubit_release(q); + }; + let __trailing_result : Int = + @generated_ident_72; + if __has_returned __ret_val else __trailing_result + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn while_for_if_return_deep() { + // while / for / if mixed nesting. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + mutable i = 0; + use q = Qubit(); + while i < 3 { + for j in 0..2 { + if M(q) == One { + return i * 10 + j; + } + } + i += 1; + } + -1 + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + mutable i : Int = 0; + let q : Qubit = __quantum__rt__qubit_allocate(); + while not __has_returned and i < 3 { + { + let + @range_id_54 : Range = 0..2; + mutable + @index_id_57 : Int = + @range_id_54::Start; + let + @step_id_62 : Int = + @range_id_54::Step; + let + @end_id_67 : Int = + @range_id_54::End; + while not __has_returned and + @step_id_62 > 0 and + @index_id_57 <= + @end_id_67 or + @step_id_62 < 0 and + @index_id_57 >= + @end_id_67 { + let j : Int = + @index_id_57; + if M(q) == One { + { + let + @generated_ident_102 : Int = i * 10 + j; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_102; + __has_returned = true; + }; + }; + } + + if not __has_returned { + @index_id_57 += + @step_id_62; + }; + } + + } + + if not __has_returned { + i += 1; + }; + } + + let + @generated_ident_114 : Int = if not __has_returned { + -1 + } else { + 0 + }; + if not __has_returned { + __quantum__rt__qubit_release(q); + }; + let __trailing_result : Int = + @generated_ident_114; + if __has_returned __ret_val else __trailing_result + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn while_inside_if_without_else_return() { + // Leaky if (no else) wrapping a while whose body returns. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + mutable i = 0; + use q = Qubit(); + if M(q) == One { + while i < 3 { + if M(q) == Zero { + return i; + } + i += 1; + } + } + -1 + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + mutable i : Int = 0; + let q : Qubit = __quantum__rt__qubit_allocate(); + if M(q) == One { + while not __has_returned and i < 3 { + if M(q) == Zero { + { + let + @generated_ident_56 : Int = i; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_56; + __has_returned = true; + }; + }; + } + + if not __has_returned { + i += 1; + }; + } + + } + + let + @generated_ident_68 : Int = if not __has_returned { + -1 + } else { + 0 + }; + if not __has_returned { + __quantum__rt__qubit_release(q); + }; + let __trailing_result : Int = + @generated_ident_68; + if __has_returned __ret_val else __trailing_result + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn for_inside_if_without_else_return() { + // Leaky if (no else) wrapping a for whose body returns. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + if M(q) == One { + for j in 0..2 { + if M(q) == Zero { + return j; + } + } + } + -1 + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + let q : Qubit = __quantum__rt__qubit_allocate(); + if M(q) == One { + { + let + @range_id_45 : Range = 0..2; + mutable + @index_id_48 : Int = + @range_id_45::Start; + let + @step_id_53 : Int = + @range_id_45::Step; + let + @end_id_58 : Int = + @range_id_45::End; + while not __has_returned and + @step_id_53 > 0 and + @index_id_48 <= + @end_id_58 or + @step_id_53 < 0 and + @index_id_48 >= + @end_id_58 { + let j : Int = + @index_id_48; + if M(q) == Zero { + { + let + @generated_ident_93 : Int = j; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_93; + __has_returned = true; + }; + }; + } + + if not __has_returned { + @index_id_48 += + @step_id_53; + }; + } + + } + + } + + let + @generated_ident_105 : Int = if not __has_returned { + -1 + } else { + 0 + }; + if not __has_returned { + __quantum__rt__qubit_release(q); + }; + let __trailing_result : Int = + @generated_ident_105; + if __has_returned __ret_val else __trailing_result + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/return_unify/normalize/tests/regression_and_depth.rs b/source/compiler/qsc_fir_transforms/src/return_unify/normalize/tests/regression_and_depth.rs new file mode 100644 index 0000000000..5601f199b0 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/return_unify/normalize/tests/regression_and_depth.rs @@ -0,0 +1,832 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! 
Predicate boundary, Category-C regression, continuation threading, +//! depth-4, use-scope carrier, and if-elseif boundary tests. + +use super::*; + +// Predicate-boundary: trivially-structured shapes stay structured + +#[test] +fn single_bare_return_at_end_stays_structured() { + // A single trailing `return` should NOT trigger the flag strategy. + // The structured path rewrites it into the trailing value with no + // `__has_returned` / `__ret_val` locals. + check_structure( + indoc! {r#" + namespace Test { + @EntryPoint() + function Main() : Int { + return 1; + } + } + "#}, + &["Main"], + &expect![[r#" + callable Main: input_ty=Unit, output_ty=Int + body: block_ty=Int + [0] Expr Lit(Int(1))"#]], + ); +} + +#[test] +fn if_then_return_else_return_at_end_stays_structured() { + // `if c { return a; } else { return b; }` should also stay structured: + // both branches return so the flag strategy is unnecessary. + check_structure( + indoc! {r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + if M(q) == One { + return 1; + } else { + return 2; + } + } + } + "#}, + &["Main"], + &expect![[r#" + callable Main: input_ty=Unit, output_ty=Int + body: block_ty=Int + [0] Local(Immutable, q: Qubit): Call[ty=Qubit] + [1] Expr If(cond=BinOp(Eq)[ty=Bool], then=Block[ty=Int], else=Block[ty=Int])"#]], + ); +} + +// Category-C regression: inner while must terminate after rewrite + +#[test] +fn nested_while_inner_only_exit_is_return_terminates() { + // The inner `while true` only exits via `return 1`. After return + // unification its condition MUST be conjoined with `not __has_returned` + // so the rewrite preserves termination. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + mutable outer = true; + while outer { + while true { + if M(q) == One { + return 1; + } + } + } + 0 + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + let q : Qubit = __quantum__rt__qubit_allocate(); + mutable outer : Bool = true; + while not __has_returned and outer { + while not __has_returned and true { + if M(q) == One { + { + let + @generated_ident_44 : Int = 1; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_44; + __has_returned = true; + }; + }; + } + + } + + } + + let + @generated_ident_56 : Int = { + 0 + }; + if not __has_returned { + __quantum__rt__qubit_release(q); + }; + let __trailing_result : Int = + @generated_ident_56; + if __has_returned __ret_val else __trailing_result + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +#[allow(clippy::too_many_lines)] +fn nested_for_inner_body_hits_return() { + // For-loops desugar to while. The desugared inner while's condition + // must also pick up the flag guard. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + for _ in 0..100 { + for _ in 0..100 { + if M(q) == One { + return 1; + } + } + } + 0 + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + let q : Qubit = __quantum__rt__qubit_allocate(); + { + let + @range_id_84 : Range = 0..100; + mutable + @index_id_87 : Int = + @range_id_84::Start; + let + @step_id_92 : Int = + @range_id_84::Step; + let + @end_id_97 : Int = + @range_id_84::End; + while not __has_returned and + @step_id_92 > 0 and + @index_id_87 <= + @end_id_97 or + @step_id_92 < 0 and + @index_id_87 >= + @end_id_97 { + let _ : Int = + @index_id_87; + { + let + @range_id_41 : Range = 0..100; + mutable + @index_id_44 : Int = + @range_id_41::Start; + let + @step_id_49 : Int = + @range_id_41::Step; + let + @end_id_54 : Int = + @range_id_41::End; + while not __has_returned and + @step_id_49 > 0 and + @index_id_44 <= + @end_id_54 or + @step_id_49 < 0 and + @index_id_44 >= + @end_id_54 { + let _ : Int = + @index_id_44; + if M(q) == One { + { + let + @generated_ident_132 : Int = 1; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_132; + __has_returned = true; + }; + }; + } + + if not __has_returned { + @index_id_44 += + @step_id_49; + }; + } + + } + + if not __has_returned { + @index_id_87 += + @step_id_92; + }; + } + + } + + let + @generated_ident_144 : Int = { + 0 + }; + if not __has_returned { + __quantum__rt__qubit_release(q); + }; + let __trailing_result : Int = + @generated_ident_144; + if __has_returned __ret_val else __trailing_result + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +// Continuation-threading regression + +#[test] +fn continuation_value_is_observed_when_inner_return_not_taken() { + // 
When the inner `return` is not taken, the outer block's trailing + // value `2` (not a synthesized default) must be observed. + check_no_returns_q( + indoc! {r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + if M(q) == One { + if M(q) == Zero { + return 1; + } + } + 2 + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + let q : Qubit = __quantum__rt__qubit_allocate(); + if M(q) == One { + if M(q) == Zero { + { + let + @generated_ident_41 : Int = 1; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_41; + __has_returned = true; + }; + }; + } + + } + + let + @generated_ident_53 : Int = if not __has_returned { + 2 + } else { + 0 + }; + if not __has_returned { + __quantum__rt__qubit_release(q); + }; + let __trailing_result : Int = + @generated_ident_53; + if __has_returned __ret_val else __trailing_result + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +// Depth-4 regressions + +#[test] +fn four_level_if_if_if_if_return_deepest() { + // Pure if-without-else chain at depth 4 with the return at the leaf. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + if M(q) == One { + if M(q) == Zero { + if M(q) == One { + if M(q) == Zero { + return 1; + } + } + } + } + 2 + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + let q : Qubit = __quantum__rt__qubit_allocate(); + if M(q) == One { + if M(q) == Zero { + if M(q) == One { + if M(q) == Zero { + { + let + @generated_ident_59 : Int = 1; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_59; + __has_returned = true; + }; + }; + } + + } + + } + + } + + let + @generated_ident_71 : Int = if not __has_returned { + 2 + } else { + 0 + }; + if not __has_returned { + __quantum__rt__qubit_release(q); + }; + let __trailing_result : Int = + @generated_ident_71; + if __has_returned __ret_val else __trailing_result + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn four_level_while_while_while_while_return_deepest() { + // Pure nested whiles at depth 4; pins the Category-C fix recursion. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + mutable i = 0; + mutable j = 0; + mutable k = 0; + mutable l = 0; + use q = Qubit(); + while i < 2 { + while j < 2 { + while k < 2 { + while l < 2 { + if M(q) == One { + return 9; + } + l += 1; + } + k += 1; + } + j += 1; + } + i += 1; + } + 0 + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + mutable i : Int = 0; + mutable j : Int = 0; + mutable k : Int = 0; + mutable l : Int = 0; + let q : Qubit = __quantum__rt__qubit_allocate(); + while not __has_returned and i < 2 { + while not __has_returned and j < 2 { + while not __has_returned and k < 2 { + while not __has_returned and l < 2 { + if M(q) == One { + { + let + @generated_ident_88 : Int = 9; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_88; + __has_returned = true; + }; + }; + } + + if not __has_returned { + l += 1; + }; + } + + if not __has_returned { + k += 1; + }; + } + + if not __has_returned { + j += 1; + }; + } + + if not __has_returned { + i += 1; + }; + } + + let + @generated_ident_100 : Int = { + 0 + }; + if not __has_returned { + __quantum__rt__qubit_release(q); + }; + let __trailing_result : Int = + @generated_ident_100; + if __has_returned __ret_val else __trailing_result + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn four_level_if_while_for_if_return_deepest() { + // Mixed shape at depth 4 with the return in the deepest `if`. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + mutable i = 0; + use q = Qubit(); + if M(q) == One { + while i < 3 { + for j in 0..2 { + if M(q) == Zero { + return i * 100 + j; + } + } + i += 1; + } + } + -1 + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + mutable i : Int = 0; + let q : Qubit = __quantum__rt__qubit_allocate(); + if M(q) == One { + while not __has_returned and i < 3 { + { + let + @range_id_63 : Range = 0..2; + mutable + @index_id_66 : Int = + @range_id_63::Start; + let + @step_id_71 : Int = + @range_id_63::Step; + let + @end_id_76 : Int = + @range_id_63::End; + while not __has_returned and + @step_id_71 > 0 and + @index_id_66 <= + @end_id_76 or + @step_id_71 < 0 and + @index_id_66 >= + @end_id_76 { + let j : Int = + @index_id_66; + if M(q) == Zero { + { + let + @generated_ident_111 : Int = i * 100 + j; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_111; + __has_returned = true; + }; + }; + } + + if not __has_returned { + @index_id_66 += + @step_id_71; + }; + } + + } + + if not __has_returned { + i += 1; + }; + } + + } + + let + @generated_ident_123 : Int = if not __has_returned { + -1 + } else { + 0 + }; + if not __has_returned { + __quantum__rt__qubit_release(q); + }; + let __trailing_result : Int = + @generated_ident_123; + if __has_returned __ret_val else __trailing_result + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +// `use`-scope carriers and `if-elseif` boundary tests + +#[test] +fn use_scope_wraps_nested_if_return_deep() { + // `use q = Qubit()` scope carrier wrapping a leaky if-if-return. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + if M(q) == One { + if M(q) == Zero { + return 1; + } + } + 2 + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + let q : Qubit = __quantum__rt__qubit_allocate(); + if M(q) == One { + if M(q) == Zero { + { + let + @generated_ident_41 : Int = 1; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_41; + __has_returned = true; + }; + }; + } + + } + + let + @generated_ident_53 : Int = if not __has_returned { + 2 + } else { + 0 + }; + if not __has_returned { + __quantum__rt__qubit_release(q); + }; + let __trailing_result : Int = + @generated_ident_53; + if __has_returned __ret_val else __trailing_result + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn if_elseif_elseif_else_return_in_last_arm() { + // if-elseif-elseif-else ladder at depth 3 with return in the last arm. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + if M(q) == One { + 1 + } elif M(q) == Zero { + 2 + } elif M(q) == One { + 3 + } else { + return 4; + } + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + let q : Qubit = __quantum__rt__qubit_allocate(); + if not M(q) == One if M(q) == Zero { + 2 + } elif M(q) == One { + 3 + } else { + { + let + @generated_ident_54 : Int = 4; + __quantum__rt__qubit_release(q); + @generated_ident_54 + } + + } else { + let + @generated_ident_66 : Int = { + 1 + }; + __quantum__rt__qubit_release(q); + @generated_ident_66 + } + + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn nested_use_scope_return_in_inner_body() { + // Two `use` scopes nested inside an if-without-else with a deep return. + check_no_returns_q( + indoc! {r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + use q0 = Qubit(); + if M(q0) == One { + use q1 = Qubit(); + if M(q1) == Zero { + return 1; + } + } + 0 + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + let q0 : Qubit = __quantum__rt__qubit_allocate(); + if M(q0) == One { + let q1 : Qubit = __quantum__rt__qubit_allocate(); + let + @generated_ident_66 : Unit = if M(q1) == Zero { + { + let + @generated_ident_50 : Int = 1; + __quantum__rt__qubit_release(q1); + __quantum__rt__qubit_release(q0); + { + __ret_val = + @generated_ident_50; + __has_returned = true; + }; + }; + }; + if not __has_returned { + __quantum__rt__qubit_release(q1); + }; + @generated_ident_66 + } + + let + @generated_ident_75 : Int = { + 0 + }; + if not __has_returned { + __quantum__rt__qubit_release(q0); + }; + let __trailing_result : Int = + @generated_ident_75; + if 
__has_returned __ret_val else __trailing_result + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/return_unify/normalize/tests/three_level.rs b/source/compiler/qsc_fir_transforms/src/return_unify/normalize/tests/three_level.rs new file mode 100644 index 0000000000..3c1a5afbf1 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/return_unify/normalize/tests/three_level.rs @@ -0,0 +1,620 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Three-level nesting tests: pure if/else/while/for combinations. + +use super::*; + +// The following tests nest block-bearing constructs three levels deep with +// `return`s placed at a variety of positions. They exercise the interaction +// between the hoist pre-pass and the strategy pass when rewrites must reach +// into deeply nested `Block`/`If`/`While`/`for` bodies. The outer callable +// uses `@EntryPoint() operation Main() : Int` so that any dynamic branch +// (driven by `M(q)`) is legal at the strategy-pass level. + +#[test] +fn three_level_if_if_if_return_in_deepest_then() { + // if / if / if -> return at the innermost then + check_no_returns_q( + indoc! 
{r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + if M(q) == One { + if M(q) == Zero { + if M(q) == One { + return 1; + } + } + } + 2 + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + let q : Qubit = __quantum__rt__qubit_allocate(); + if M(q) == One { + if M(q) == Zero { + if M(q) == One { + { + let + @generated_ident_50 : Int = 1; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_50; + __has_returned = true; + }; + }; + } + + } + + } + + let + @generated_ident_62 : Int = if not __has_returned { + 2 + } else { + 0 + }; + if not __has_returned { + __quantum__rt__qubit_release(q); + }; + let __trailing_result : Int = + @generated_ident_62; + if __has_returned __ret_val else __trailing_result + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn three_level_if_else_chain_return_in_deepest_else() { + // if { ... } else { if { ... } else { if c { x } else { return v } } } + check_no_returns_q( + indoc! 
{r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + if M(q) == One { + 1 + } else { + if M(q) == Zero { + 2 + } else { + if M(q) == One { + 3 + } else { + return 4; + } + } + } + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + let q : Qubit = __quantum__rt__qubit_allocate(); + if not M(q) == One { + if not M(q) == Zero { + if not M(q) == One { + { + let + @generated_ident_60 : Int = 4; + __quantum__rt__qubit_release(q); + @generated_ident_60 + } + + } else { + 3 + } + + } else { + 2 + } + + } else { + let + @generated_ident_72 : Int = { + 1 + }; + __quantum__rt__qubit_release(q); + @generated_ident_72 + } + + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn three_level_while_while_while_return_deep() { + // while / while / while -> return deep in the innermost body + check_no_returns_q( + indoc! 
{r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + mutable i = 0; + mutable j = 0; + mutable k = 0; + use q = Qubit(); + while i < 2 { + while j < 2 { + while k < 2 { + if M(q) == One { + return 7; + } + k += 1; + } + j += 1; + } + i += 1; + } + 0 + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + mutable i : Int = 0; + mutable j : Int = 0; + mutable k : Int = 0; + let q : Qubit = __quantum__rt__qubit_allocate(); + while not __has_returned and i < 2 { + while not __has_returned and j < 2 { + while not __has_returned and k < 2 { + if M(q) == One { + { + let + @generated_ident_74 : Int = 7; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_74; + __has_returned = true; + }; + }; + } + + if not __has_returned { + k += 1; + }; + } + + if not __has_returned { + j += 1; + }; + } + + if not __has_returned { + i += 1; + }; + } + + let + @generated_ident_86 : Int = { + 0 + }; + if not __has_returned { + __quantum__rt__qubit_release(q); + }; + let __trailing_result : Int = + @generated_ident_86; + if __has_returned __ret_val else __trailing_result + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +#[allow(clippy::too_many_lines)] +fn three_level_for_for_for_return_deep() { + // for / for / for -> return deep inside the innermost body + check_no_returns_q( + indoc! 
{r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + for a in 0..2 { + for b in 0..2 { + for c in 0..2 { + if M(q) == One { + return a + b + c; + } + } + } + } + 0 + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + let q : Qubit = __quantum__rt__qubit_allocate(); + { + let + @range_id_141 : Range = 0..2; + mutable + @index_id_144 : Int = + @range_id_141::Start; + let + @step_id_149 : Int = + @range_id_141::Step; + let + @end_id_154 : Int = + @range_id_141::End; + while not __has_returned and + @step_id_149 > 0 and + @index_id_144 <= + @end_id_154 or + @step_id_149 < 0 and + @index_id_144 >= + @end_id_154 { + let a : Int = + @index_id_144; + { + let + @range_id_98 : Range = 0..2; + mutable + @index_id_101 : Int = + @range_id_98::Start; + let + @step_id_106 : Int = + @range_id_98::Step; + let + @end_id_111 : Int = + @range_id_98::End; + while not __has_returned and + @step_id_106 > 0 and + @index_id_101 <= + @end_id_111 or + @step_id_106 < 0 and + @index_id_101 >= + @end_id_111 { + let b : Int = + @index_id_101; + { + let + @range_id_55 : Range = 0..2; + mutable + @index_id_58 : Int = + @range_id_55::Start; + let + @step_id_63 : Int = + @range_id_55::Step; + let + @end_id_68 : Int = + @range_id_55::End; + while not __has_returned and + @step_id_63 > 0 and + @index_id_58 <= + @end_id_68 or + @step_id_63 < 0 and + @index_id_58 >= + @end_id_68 { + let c : Int = + @index_id_58; + if M(q) == One { + { + let + @generated_ident_189 : Int = a + b + c; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_189; + __has_returned = true; + }; + }; + } + + if not __has_returned { + @index_id_58 += + @step_id_63; + }; + } + + } + + if not __has_returned { + @index_id_101 += + @step_id_106; + }; + } + + } + + if not __has_returned { + @index_id_144 += + @step_id_149; + }; + } + + } + + let + 
@generated_ident_201 : Int = { + 0 + }; + if not __has_returned { + __quantum__rt__qubit_release(q); + }; + let __trailing_result : Int = + @generated_ident_201; + if __has_returned __ret_val else __trailing_result + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn three_level_for_while_if_return_deep() { + // for / while / if -> return inside the if + check_no_returns_q( + indoc! {r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + for i in 0..2 { + mutable j = 0; + while j < 2 { + if M(q) == One { + return i * 10 + j; + } + j += 1; + } + } + 0 + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + let q : Qubit = __quantum__rt__qubit_allocate(); + { + let + @range_id_53 : Range = 0..2; + mutable + @index_id_56 : Int = + @range_id_53::Start; + let + @step_id_61 : Int = + @range_id_53::Step; + let + @end_id_66 : Int = + @range_id_53::End; + while not __has_returned and + @step_id_61 > 0 and + @index_id_56 <= + @end_id_66 or + @step_id_61 < 0 and + @index_id_56 >= + @end_id_66 { + let i : Int = + @index_id_56; + mutable j : Int = 0; + while not __has_returned and j < 2 { + if M(q) == One { + { + let + @generated_ident_101 : Int = i * 10 + j; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_101; + __has_returned = true; + }; + }; + } + + if not __has_returned { + j += 1; + }; + } + + if not __has_returned { + @index_id_56 += + @step_id_61; + }; + } + + } + + let + @generated_ident_113 : Int = { + 0 + }; + if not __has_returned { + __quantum__rt__qubit_release(q); + }; + let __trailing_result : Int = + @generated_ident_113; + if __has_returned __ret_val else __trailing_result + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + 
function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn three_level_if_while_for_return_deep() { + // if / while / for -> return inside the for + check_no_returns_q( + indoc! {r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + if M(q) == One { + mutable i = 0; + while i < 3 { + for j in 0..2 { + if M(q) == Zero { + return i + j; + } + } + i += 1; + } + } + -1 + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + let q : Qubit = __quantum__rt__qubit_allocate(); + if M(q) == One { + mutable i : Int = 0; + while not __has_returned and i < 3 { + { + let + @range_id_61 : Range = 0..2; + mutable + @index_id_64 : Int = + @range_id_61::Start; + let + @step_id_69 : Int = + @range_id_61::Step; + let + @end_id_74 : Int = + @range_id_61::End; + while not __has_returned and + @step_id_69 > 0 and + @index_id_64 <= + @end_id_74 or + @step_id_69 < 0 and + @index_id_64 >= + @end_id_74 { + let j : Int = + @index_id_64; + if M(q) == Zero { + { + let + @generated_ident_109 : Int = i + j; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_109; + __has_returned = true; + }; + }; + } + + if not __has_returned { + @index_id_64 += + @step_id_69; + }; + } + + } + + if not __has_returned { + i += 1; + }; + } + + } + + let + @generated_ident_121 : Int = if not __has_returned { + -1 + } else { + 0 + }; + if not __has_returned { + __quantum__rt__qubit_release(q); + }; + let __trailing_result : Int = + @generated_ident_121; + if __has_returned __ret_val else __trailing_result + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/return_unify/normalize/tests/three_level_mixed.rs 
b/source/compiler/qsc_fir_transforms/src/return_unify/normalize/tests/three_level_mixed.rs new file mode 100644 index 0000000000..8778a2b587 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/return_unify/normalize/tests/three_level_mixed.rs @@ -0,0 +1,475 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Three-level nesting tests: mixed constructs, blocks, qubit scopes, +//! multi-level returns, and compound-position returns at depth. + +use super::*; + +#[test] +fn three_level_block_block_if_returns_at_each_level() { + // nested bare blocks with returns sprinkled at every level + check_no_returns_q( + indoc! {r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + { + if M(q) == One { + return 1; + } + { + if M(q) == Zero { + return 2; + } + { + if M(q) == One { + return 3; + } + 4 + } + } + } + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + let q : Qubit = __quantum__rt__qubit_allocate(); + let + @generated_ident_101 : Int = { + if M(q) == One { + { + let + @generated_ident_65 : Int = 1; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_65; + __has_returned = true; + }; + }; + } + + { + if M(q) == Zero { + { + let + @generated_ident_77 : Int = 2; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_77; + __has_returned = true; + }; + }; + } + + { + if M(q) == One { + { + let + @generated_ident_89 : Int = 3; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_89; + __has_returned = true; + }; + }; + } + + 4 + } + + } + + }; + if not __has_returned { + __quantum__rt__qubit_release(q); + }; + let __trailing_result : Int = + @generated_ident_101; + if __has_returned __ret_val else __trailing_result + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : 
Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn three_level_qubit_scopes_with_deep_return() { + // Three nested qubit allocation scopes; return deep inside the innermost + // scope. The strategy pass must preserve the release order of all three + // qubit scopes on the return path. + check_no_returns_q( + indoc! {r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + use q0 = Qubit(); + if M(q0) == One { + use q1 = Qubit(); + if M(q1) == One { + use q2 = Qubit(); + if M(q2) == One { + return 42; + } + } + } + 0 + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + let q0 : Qubit = __quantum__rt__qubit_allocate(); + if M(q0) == One { + let q1 : Qubit = __quantum__rt__qubit_allocate(); + let + @generated_ident_97 : Unit = if M(q1) == One { + let q2 : Qubit = __quantum__rt__qubit_allocate(); + let + @generated_ident_88 : Unit = if M(q2) == One { + { + let + @generated_ident_68 : Int = 42; + __quantum__rt__qubit_release(q2); + __quantum__rt__qubit_release(q1); + __quantum__rt__qubit_release(q0); + { + __ret_val = + @generated_ident_68; + __has_returned = true; + }; + }; + }; + if not __has_returned { + __quantum__rt__qubit_release(q2); + }; + @generated_ident_88 + }; + if not __has_returned { + __quantum__rt__qubit_release(q1); + }; + @generated_ident_97 + } + + let + @generated_ident_106 : Int = { + 0 + }; + if not __has_returned { + __quantum__rt__qubit_release(q0); + }; + let __trailing_result : Int = + @generated_ident_106; + if __has_returned __ret_val else __trailing_result + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn three_level_nested_returns_at_every_level() { + // Each level has its own return on its own branch; the strategy pass + // must 
flatten all three into a single post-unification control flow. + check_no_returns_q( + indoc! {r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + if M(q) == One { + return 1; + } + if M(q) == Zero { + if M(q) == One { + return 2; + } + if M(q) == Zero { + if M(q) == One { + return 3; + } + } + } + 0 + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + let q : Qubit = __quantum__rt__qubit_allocate(); + if M(q) == One { + { + let + @generated_ident_74 : Int = 1; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_74; + __has_returned = true; + }; + }; + } + + if not __has_returned { + if M(q) == Zero { + if M(q) == One { + { + let + @generated_ident_86 : Int = 2; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_86; + __has_returned = true; + }; + }; + } + + if M(q) == Zero { + if M(q) == One { + { + let + @generated_ident_98 : Int = 3; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_98; + __has_returned = true; + }; + }; + } + + } + + } + + }; + let + @generated_ident_110 : Int = { + 0 + }; + if not __has_returned { + __quantum__rt__qubit_release(q); + }; + let __trailing_result : Int = + @generated_ident_110; + if __has_returned __ret_val else __trailing_result + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn three_level_hoist_return_in_call_arg_deep() { + // Compound-position return three constructs deep: the inner `Return` + // sits inside a `Call` argument inside an `if` inside a `while` inside + // a `for`. Exercises the hoist pre-pass driving the strategy pass at + // depth. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + function Add(a : Int, b : Int) : Int { a + b } + @EntryPoint() + operation Main() : Int { + mutable total = 0; + for i in 0..1 { + mutable j = 0; + while j < 2 { + if i == j { + total = Add(total, (return i * 100 + j)); + } + j += 1; + } + } + total + } + } + "#}, + &expect![[r#" + // namespace Test + function Add(a : Int, b : Int) : Int { + body { + a + b + } + } + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + mutable total : Int = 0; + { + let + @range_id_70 : Range = 0..1; + mutable + @index_id_73 : Int = + @range_id_70::Start; + let + @step_id_78 : Int = + @range_id_70::Step; + let + @end_id_83 : Int = + @range_id_70::End; + while not __has_returned and + @step_id_78 > 0 and + @index_id_73 <= + @end_id_83 or + @step_id_78 < 0 and + @index_id_73 >= + @end_id_83 { + let i : Int = + @index_id_73; + mutable j : Int = 0; + while not __has_returned and j < 2 { + if i == j { + let _ : Int = total; + let _ : ((Int, Int) -> Int) = Add; + let _ : Int = total; + { + __ret_val = i * 100 + j; + __has_returned = true; + }; + } + + if not __has_returned { + j += 1; + }; + } + + if not __has_returned { + @index_id_73 += + @step_id_78; + }; + } + + } + + let __trailing_result : Int = total; + if __has_returned __ret_val else __trailing_result + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn three_level_outer_return_wraps_three_deep_block() { + // An outer bare `return` wrapping three levels of block-bearing + // constructs whose leaf holds a statement-level return. Exercises the + // `bind_inner_and_return` path across multiple nesting levels. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + import Std.Measurement.*; + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + return if M(q) == One { + if M(q) == Zero { + if M(q) == One { + return 1; + } + 2 + } else { + 3 + } + } else { + 4 + }; + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + let q : Qubit = __quantum__rt__qubit_allocate(); + { + if M(q) == One { + if M(q) == Zero { + if M(q) == One { + { + let + @generated_ident_60 : Int = 1; + __quantum__rt__qubit_release(q); + @generated_ident_60 + } + + } else { + 2 + } + + } else { + 3 + } + + } else { + let + @generated_ident_59 : Int = { + 4 + }; + __quantum__rt__qubit_release(q); + @generated_ident_59 + } + + } + + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/return_unify/semantic_equivalence_tests.rs b/source/compiler/qsc_fir_transforms/src/return_unify/semantic_equivalence_tests.rs new file mode 100644 index 0000000000..51dfd64381 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/return_unify/semantic_equivalence_tests.rs @@ -0,0 +1,182 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use crate::test_utils::check_semantic_equivalence; +use indoc::formatdoc; +use proptest::prelude::*; + +/// Generates syntactically valid Q# programs with return statements at +/// various positions covering all `return_unify` dispatch categories +/// (structured, flag, no-return). Each program wraps one of 12 template +/// patterns in a `namespace Test { function Main() : Int { ... } }` shell. +#[allow(clippy::too_many_lines)] +fn return_pattern_strategy() -> impl Strategy { + let cmp = || 0..10i64; + let val = || 0..100i64; + let bound = || 1..6i64; + let idx = || 0..5i64; + + prop_oneof![ + // 1. No-return baseline: pure if-else expression. 
+ (cmp(), cmp(), val(), val()).prop_map(|(a, b, c, d)| formatdoc! {" + namespace Test {{ + function Main() : Int {{ + if {a} > {b} {{ {c} }} else {{ {d} }} + }} + }} + "}), + // 2. Single guard clause. + (cmp(), cmp(), val(), val()).prop_map(|(a, b, c, d)| formatdoc! {" + namespace Test {{ + function Main() : Int {{ + if {a} > {b} {{ return {c}; }} + {d} + }} + }} + "}), + // 3. Both branches return. + (cmp(), cmp(), val(), val()).prop_map(|(a, b, c, d)| formatdoc! {" + namespace Test {{ + function Main() : Int {{ + if {a} > {b} {{ return {c}; }} else {{ return {d}; }} + }} + }} + "}), + // 4. Two guard clauses with fallthrough. + (cmp(), cmp(), cmp(), cmp(), val(), val(), val()).prop_map( + |(a, b, c, d, e, f, g)| formatdoc! {" + namespace Test {{ + function Main() : Int {{ + if {a} > {b} {{ return {e}; }} + if {c} > {d} {{ return {f}; }} + {g} + }} + }} + "} + ), + // 5. While with early return. + (bound(), idx(), val(), val()).prop_map(|(n, t, v, d)| formatdoc! {" + namespace Test {{ + function Main() : Int {{ + mutable x = 0; + while x < {n} {{ + if x == {t} {{ return {v}; }} + x += 1; + }} + {d} + }} + }} + "}), + // 6. For loop with early return. + (bound(), idx(), val(), val()).prop_map(|(n, t, v, d)| formatdoc! {" + namespace Test {{ + function Main() : Int {{ + for i in 0..{n} {{ + if i == {t} {{ return {v}; }} + }} + {d} + }} + }} + "}), + // 7. Nested if with return. + (cmp(), cmp(), cmp(), cmp(), val(), val(), val()).prop_map( + |(a, b, c, d, e, f, g)| formatdoc! {" + namespace Test {{ + function Main() : Int {{ + if {a} > {b} {{ + if {c} > {d} {{ return {e}; }} + {f} + }} else {{ + {g} + }} + }} + }} + "} + ), + // 8. Block expression with return. + (cmp(), cmp(), val(), val(), val()).prop_map(|(a, b, c, d, e)| formatdoc! {" + namespace Test {{ + function Main() : Int {{ + let x = {{ + if {a} > {b} {{ return {c}; }} + {d} + }}; + x + {e} + }} + }} + "}), + // 9. Return in else branch only. 
+ (cmp(), cmp(), val(), val()).prop_map(|(a, b, c, d)| formatdoc! {" + namespace Test {{ + function Main() : Int {{ + if {a} > {b} {{ {c} }} else {{ return {d}; }} + }} + }} + "}), + // 10. Multiple returns with mutable computation. + (cmp(), cmp(), cmp(), cmp(), val(), val(), val(), val()).prop_map( + |(a, b, c, d, e, f, g, h)| formatdoc! {" + namespace Test {{ + function Main() : Int {{ + mutable result = 0; + if {a} > {b} {{ return {e}; }} + result = {f}; + if {c} > {d} {{ return {g}; }} + result + {h} + }} + }} + "} + ), + // 11. Triple nested if-return. + ( + cmp(), + cmp(), + cmp(), + cmp(), + cmp(), + cmp(), + val(), + val(), + val(), + val() + ) + .prop_map(|(a, b, c, d, e, f, g, h, i, j)| formatdoc! {" + namespace Test {{ + function Main() : Int {{ + if {a} > {b} {{ + if {c} > {d} {{ + if {e} > {f} {{ return {g}; }} + return {h}; + }} + {i} + }} else {{ + return {j}; + }} + }} + }} + "}), + // 12. While with accumulator and conditional return. + (bound(), idx()).prop_map(|(n, t)| formatdoc! {" + namespace Test {{ + function Main() : Int {{ + mutable acc = 0; + mutable i = 0; + while i < {n} {{ + if i > {t} {{ return acc; }} + acc = acc + i; + i += 1; + }} + acc + }} + }} + "}), + ] +} + +proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + #[test] + fn differential_return_unify(source in return_pattern_strategy()) { + check_semantic_equivalence(&source); + } +} diff --git a/source/compiler/qsc_fir_transforms/src/return_unify/tests.rs b/source/compiler/qsc_fir_transforms/src/return_unify/tests.rs new file mode 100644 index 0000000000..610ad81c2c --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/return_unify/tests.rs @@ -0,0 +1,826 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#![allow(clippy::needless_raw_string_hashes)] + +//! Tests for the return unification pass. 
+ +mod contracts_and_errors; +mod flag_strategy; +mod idempotency; +mod qubit_release; +mod regressions; +mod semantic; +mod structured_strategy; +mod type_preservation; + +use expect_test::{Expect, expect}; +use rustc_hash::FxHashSet; + +use crate::reachability::collect_reachable_from_entry; +use crate::test_utils::{ + PipelineStage, compile_and_run_pipeline_to, compile_and_run_pipeline_to_with_errors, + compile_to_fir, +}; +use crate::walk_utils::{for_each_expr, for_each_expr_in_callable_impl}; +use indoc::indoc; +use qsc_data_structures::{ + language_features::LanguageFeatures, source::SourceMap, target::TargetCapabilityFlags, +}; +use qsc_fir::assigner::Assigner; +use qsc_fir::fir::{ + BinOp, BlockId, CallableImpl, Expr, ExprId, ExprKind, ItemKind, Lit, LocalVarId, Package, + PackageId, PackageLookup, PackageStore, Pat, PatKind, Res, StmtId, StmtKind, StoreItemId, UnOp, +}; +use qsc_fir::ty::{Prim, Ty}; + +pub(crate) type ReleaseCallableSet = FxHashSet; + +/// Collects the set of callables that release qubit allocations. +pub(crate) fn collect_release_callables(store: &PackageStore) -> ReleaseCallableSet { + let mut release_callables = FxHashSet::default(); + for (package_id, package) in store { + for (item_id, item) in &package.items { + let ItemKind::Callable(decl) = &item.kind else { + continue; + }; + if matches!( + decl.name.name.as_ref(), + "__quantum__rt__qubit_release" | "ReleaseQubitArray" + ) { + release_callables.insert(StoreItemId { + package: package_id, + item: item_id, + }); + } + } + } + release_callables +} + +/// Test-only reimplementation of the removed `is_release_call` helper. 
+fn is_release_call_test( + package: &Package, + stmt_id: StmtId, + release_set: &ReleaseCallableSet, +) -> bool { + let stmt = package.get_stmt(stmt_id); + let StmtKind::Semi(expr_id) = &stmt.kind else { + return false; + }; + let expr = package.get_expr(*expr_id); + let ExprKind::Call(callee_id, _) = &expr.kind else { + return false; + }; + let callee = package.get_expr(*callee_id); + let ExprKind::Var(Res::Item(item_id), _) = &callee.kind else { + return false; + }; + release_set.contains(&StoreItemId { + package: item_id.package, + item: item_id.item, + }) +} + +struct NoHoistReturnUnifyResult { + store: PackageStore, + pkg_id: PackageId, + before: String, + after: String, +} + +impl NoHoistReturnUnifyResult { + fn before_after(&self) -> String { + format!( + "// before direct no-hoist return_unify\n{}\n// post direct no-hoist return_unify\n{}", + self.before, self.after + ) + } +} + +pub(crate) fn assert_no_reachable_returns(store: &PackageStore, pkg_id: PackageId) { + let package = store.get(pkg_id); + let reachable = collect_reachable_from_entry(store, pkg_id); + + for store_id in &reachable { + if store_id.package != pkg_id { + continue; + } + let item = package.get_item(store_id.item); + if let ItemKind::Callable(decl) = &item.kind { + for_each_expr_in_callable_impl(package, &decl.implementation, &mut |_id, expr| { + assert!( + !matches!(expr.kind, ExprKind::Return(_)), + "Return node found in callable '{}' after direct no-hoist return unification", + decl.name.name + ); + }); + } + } +} + +fn compile_no_hoist_return_unified(source: &str) -> NoHoistReturnUnifyResult { + let (mut store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Mono); + let before = crate::pretty::write_package_qsharp(&store, pkg_id); + + let mut assigner = Assigner::from_package(store.get(pkg_id)); + let errors = super::unify_returns(&mut store, pkg_id, &mut assigner); + assert!( + errors.is_empty(), + "direct no-hoist return_unify produced errors: {errors:?}\n// before 
direct no-hoist return_unify\n{before}" + ); + assert_no_reachable_returns(&store, pkg_id); + + let after = crate::pretty::write_package_qsharp(&store, pkg_id); + NoHoistReturnUnifyResult { + store, + pkg_id, + before, + after, + } +} + +fn release_store_id(package: &Package, expr: &Expr) -> Option { + let ExprKind::Call(callee_id, _) = &expr.kind else { + return None; + }; + let callee = package.get_expr(*callee_id); + let ExprKind::Var(Res::Item(item_id), _) = &callee.kind else { + return None; + }; + Some(StoreItemId { + package: item_id.package, + item: item_id.item, + }) +} + +fn expr_contains_release_call( + package: &Package, + expr_id: ExprId, + release_set: &ReleaseCallableSet, +) -> bool { + let mut has_release = false; + for_each_expr(package, expr_id, &mut |_id, expr| { + has_release |= release_store_id(package, expr).is_some_and(|id| release_set.contains(&id)); + }); + has_release +} + +fn stmt_contains_path_local_release_value( + package: &Package, + stmt_id: StmtId, + release_set: &ReleaseCallableSet, +) -> bool { + let stmt = package.get_stmt(stmt_id); + match stmt.kind { + StmtKind::Local(_, _, init_expr_id) | StmtKind::Expr(init_expr_id) => { + expr_contains_release_call(package, init_expr_id, release_set) + } + StmtKind::Semi(expr_id) => { + release_store_id(package, package.get_expr(expr_id)).is_none() + && expr_contains_release_call(package, expr_id, release_set) + } + StmtKind::Item(_) => false, + } +} + +fn assert_path_local_releases_without_unconditional_suffix( + result: &NoHoistReturnUnifyResult, + callable_name: &str, +) { + let package = result.store.get(result.pkg_id); + let release_set = collect_release_callables(&result.store); + let body_block_id = find_body_block_id(package, callable_name); + let body_block = package.get_block(body_block_id); + + let Some(path_local_release_index) = body_block.stmts.iter().position(|&stmt_id| { + stmt_contains_path_local_release_value(package, stmt_id, &release_set) + }) else { + panic!( + 
"{callable_name} should preserve at least one path-local release after direct no-hoist return_unify\n{}", + result.before_after() + ); + }; + + let release_suffix_after_path_local = body_block.stmts[path_local_release_index + 1..] + .iter() + .any(|&stmt_id| is_release_call_test(package, stmt_id, &release_set)); + + assert!( + !release_suffix_after_path_local, + "{callable_name} should not run an unconditional release suffix after a value path that already contains path-local releases\n{}", + result.before_after() + ); +} + +fn expr_contains_guarded_release_call( + package: &Package, + expr_id: ExprId, + release_set: &ReleaseCallableSet, + has_returned_var_id: LocalVarId, +) -> bool { + let mut found_guarded_release = false; + for_each_expr(package, expr_id, &mut |_id, expr| { + let ExprKind::If(cond_expr_id, then_expr_id, None) = &expr.kind else { + return; + }; + + found_guarded_release |= is_not_flag_expr(package, *cond_expr_id, has_returned_var_id) + && expr_contains_release_call(package, *then_expr_id, release_set); + }); + found_guarded_release +} + +fn assert_guarded_release_continuation(result: &NoHoistReturnUnifyResult, callable_name: &str) { + let package = result.store.get(result.pkg_id); + let release_set = collect_release_callables(&result.store); + let (flag_pat, _) = find_local_init(package, callable_name, "__has_returned"); + let has_returned_var_id = local_var_id_from_named_pat(flag_pat, "__has_returned"); + let decl = find_callable_decl(package, callable_name); + + let mut found_guarded_release = false; + for_each_expr_in_callable_impl(package, &decl.implementation, &mut |expr_id, _expr| { + found_guarded_release |= + expr_contains_guarded_release_call(package, expr_id, &release_set, has_returned_var_id); + }); + + assert!( + found_guarded_release, + "{callable_name} should guard release continuations with not __has_returned after direct no-hoist return_unify\n{}", + result.before_after() + ); +} + +fn eval_qsharp_no_hoist_return_unified(source: 
&str) -> Result {
+    let NoHoistReturnUnifyResult {
+        mut store, pkg_id, ..
+    } = compile_no_hoist_return_unified(source);
+    crate::exec_graph_rebuild::rebuild_exec_graphs(&mut store, pkg_id, &[]);
+    try_eval_fir_entry(&store, pkg_id)
+}
+
+fn check_no_hoist_semantic_equivalence(source: &str) {
+    let expected = eval_qsharp_original(source);
+    let actual = eval_qsharp_no_hoist_return_unified(source);
+
+    match (&expected, &actual) {
+        (Ok(exp_val), Ok(act_val)) => {
+            assert_eq!(
+                exp_val, act_val,
+                "direct no-hoist return_unify semantic equivalence violated: original returned {exp_val}, transformed returned {act_val}"
+            );
+        }
+        (Err(exp_err), Err(act_err)) => {
+            assert_eq!(
+                exp_err, act_err,
+                "direct no-hoist return_unify semantic equivalence violated: original failed with {exp_err}, transformed failed with {act_err}"
+            );
+        }
+        (Ok(exp_val), Err(err)) => {
+            panic!(
+                "original succeeded with {exp_val} but direct no-hoist return_unify failed: {err}"
+            );
+        }
+        (Err(err), Ok(act_val)) => {
+            panic!(
+                "original failed with {err} but direct no-hoist return_unify succeeded with {act_val}"
+            );
+        }
+    }
+}
+
+/// Compiles source through mono + `return_unify` and asserts no Return nodes
+/// remain in any reachable callable. Returns the transformed package store
+/// together with its package ID.
+pub(crate) fn compile_return_unified( + source: &str, +) -> (qsc_fir::fir::PackageStore, qsc_fir::fir::PackageId) { + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::ReturnUnify); + assert_no_reachable_returns(&store, pkg_id); + + (store, pkg_id) +} + +fn describe_pat(package: &Package, pat_id: qsc_fir::fir::PatId) -> String { + let pat = package.get_pat(pat_id); + match &pat.kind { + PatKind::Bind(ident) => format!("{}: {}", ident.name, pat.ty), + PatKind::Tuple(items) => format!( + "({})", + items + .iter() + .map(|&item| describe_pat(package, item)) + .collect::>() + .join(", ") + ), + PatKind::Discard => format!("_: {}", pat.ty), + } +} + +fn push_spec_summary( + package: &Package, + label: &str, + spec: &qsc_fir::fir::SpecDecl, + lines: &mut Vec, +) { + let block = package.get_block(spec.block); + lines.push(format!(" {label}: block_ty={}", block.ty)); + for (index, stmt_id) in block.stmts.iter().enumerate() { + let stmt = package.get_stmt(*stmt_id); + let line = match &stmt.kind { + StmtKind::Expr(expr_id) => { + format!( + " [{index}] Expr {}", + describe_expr(package, *expr_id) + ) + } + StmtKind::Semi(expr_id) => { + format!( + " [{index}] Semi {}", + describe_expr(package, *expr_id) + ) + } + StmtKind::Local(mutability, pat_id, expr_id) => format!( + " [{index}] Local({mutability:?}, {}): {}", + describe_pat(package, *pat_id), + describe_expr(package, *expr_id) + ), + StmtKind::Item(local_item_id) => format!(" [{index}] Item {local_item_id}"), + }; + lines.push(line); + } +} + +fn summarize_callable(package: &Package, callable_name: &str) -> String { + let decl = package + .items + .values() + .find_map(|item| match &item.kind { + ItemKind::Callable(decl) if decl.name.name.as_ref() == callable_name => Some(decl), + _ => None, + }) + .unwrap_or_else(|| panic!("callable '{callable_name}' not found")); + + let mut lines = vec![format!( + "callable {}: input_ty={}, output_ty={}", + decl.name.name, + package.get_pat(decl.input).ty, + 
decl.output + )]; + + match &decl.implementation { + CallableImpl::Intrinsic => lines.push(" intrinsic".to_string()), + CallableImpl::Spec(spec_impl) => { + push_spec_summary(package, "body", &spec_impl.body, &mut lines); + for (label, spec) in [ + ("adj", spec_impl.adj.as_ref()), + ("ctl", spec_impl.ctl.as_ref()), + ("ctl_adj", spec_impl.ctl_adj.as_ref()), + ] { + if let Some(spec) = spec { + push_spec_summary(package, label, spec, &mut lines); + } + } + } + CallableImpl::SimulatableIntrinsic(spec) => { + push_spec_summary(package, "simulatable", spec, &mut lines); + } + } + + lines.join("\n") +} + +/// Check the structure of callables after return unification. +pub(crate) fn check_structure(source: &str, callable_names: &[&str], expect: &Expect) { + let (store, pkg_id) = compile_return_unified(source); + let package = store.get(pkg_id); + let summary = callable_names + .iter() + .map(|callable_name| summarize_callable(package, callable_name)) + .collect::>() + .join("\n"); + expect.assert_eq(&summary); +} + +/// Compile, run the pipeline through `ReturnUnify`, assert no +/// `ExprKind::Return` survives in any reachable callable, and pin the +/// resulting FIR as formatted Q# via `expect_test`. +/// +/// The `expect` snapshot is generated by +/// [`crate::pretty::write_package_qsharp`]. 
+pub(crate) fn check_no_returns_q(source: &str, expect: &Expect) { + let (store, pkg_id) = compile_return_unified(source); + let rendered = crate::pretty::write_package_qsharp(&store, pkg_id); + expect.assert_eq(&rendered); +} + +fn check_pre_fir_transforms_to_return_unify_q(source: &str, expect: &Expect) { + let (before_store, before_pkg_id) = compile_to_fir(source); + let before = crate::pretty::write_package_qsharp(&before_store, before_pkg_id); + + let (after_store, after_pkg_id) = compile_return_unified(source); + let after = crate::pretty::write_package_qsharp(&after_store, after_pkg_id); + + expect.assert_eq(&format!( + "// before fir transforms\n{before}\n// post return_unify\n{after}" + )); +} + +fn find_local_init<'a>( + package: &'a Package, + callable_name: &str, + local_name: &str, +) -> (&'a Pat, &'a Expr) { + for item in package.items.values() { + if let ItemKind::Callable(decl) = &item.kind + && decl.name.name.as_ref() == callable_name + && let CallableImpl::Spec(spec) = &decl.implementation + { + let block = package.get_block(spec.body.block); + for &stmt_id in &block.stmts { + let stmt = package.get_stmt(stmt_id); + let StmtKind::Local(_, pat_id, init_expr_id) = &stmt.kind else { + continue; + }; + let pat = package.get_pat(*pat_id); + if let PatKind::Bind(ident) = &pat.kind + && ident.name.as_ref() == local_name + { + return (pat, package.get_expr(*init_expr_id)); + } + } + } + } + + panic!("local '{local_name}' not found in callable '{callable_name}'"); +} + +fn find_callable_decl<'a>( + package: &'a Package, + callable_name: &str, +) -> &'a qsc_fir::fir::CallableDecl { + package + .items + .values() + .find_map(|item| match &item.kind { + ItemKind::Callable(decl) if decl.name.name.as_ref() == callable_name => Some(decl), + _ => None, + }) + .unwrap_or_else(|| panic!("callable '{callable_name}' not found")) +} + +fn find_body_block_id(package: &Package, callable_name: &str) -> BlockId { + let decl = find_callable_decl(package, callable_name); + 
let CallableImpl::Spec(spec_impl) = &decl.implementation else { + panic!("callable '{callable_name}' must have a body spec") + }; + spec_impl.body.block +} + +fn local_var_id_from_named_pat(pat: &Pat, local_name: &str) -> LocalVarId { + let PatKind::Bind(ident) = &pat.kind else { + panic!("local '{local_name}' should bind a single local var") + }; + ident.id +} + +fn expr_reads_local(package: &Package, expr_id: ExprId, expected_local: LocalVarId) -> bool { + matches!( + &package.get_expr(expr_id).kind, + ExprKind::Var(Res::Local(local_id), _) if *local_id == expected_local + ) +} + +fn is_not_flag_expr(package: &Package, expr_id: ExprId, has_returned_var_id: LocalVarId) -> bool { + let ExprKind::UnOp(UnOp::NotL, inner_expr_id) = &package.get_expr(expr_id).kind else { + return false; + }; + expr_reads_local(package, *inner_expr_id, has_returned_var_id) +} + +fn assert_while_condition_guarded_by_not_flag( + package: &Package, + cond_expr_id: ExprId, + has_returned_var_id: LocalVarId, +) { + let ExprKind::BinOp(BinOp::AndL, lhs_expr_id, _rhs_expr_id) = + &package.get_expr(cond_expr_id).kind + else { + panic!("while condition should be rewritten to not __has_returned and cond") + }; + + assert!( + is_not_flag_expr(package, *lhs_expr_id, has_returned_var_id), + "while condition LHS should be not __has_returned" + ); +} + +fn assignment_target_local(package: &Package, expr_id: ExprId) -> Option { + let ExprKind::Assign(lhs_expr_id, _rhs_expr_id) = &package.get_expr(expr_id).kind else { + return None; + }; + let ExprKind::Var(Res::Local(local_id), _) = &package.get_expr(*lhs_expr_id).kind else { + return None; + }; + Some(*local_id) +} + +fn assert_local_initializer_then_assign_order( + package: &Package, + init_expr_id: ExprId, + ret_val_var_id: LocalVarId, + has_returned_var_id: LocalVarId, +) -> bool { + let ExprKind::If(_cond_expr_id, _then_expr_id, _else_expr_id) = + &package.get_expr(init_expr_id).kind + else { + panic!("expected Local initializer to remain an 
if-expression") + }; + + let mut writes = Vec::new(); + for_each_expr(package, init_expr_id, &mut |_expr_id, expr| { + let ExprKind::Assign(lhs_expr_id, _rhs_expr_id) = &expr.kind else { + return; + }; + if let Some(target_local) = assignment_target_local(package, *lhs_expr_id) { + writes.push(target_local); + } + }); + + let Some(ret_write_idx) = writes.iter().position(|local| *local == ret_val_var_id) else { + return false; + }; + let Some(flag_write_idx) = writes + .iter() + .position(|local| *local == has_returned_var_id) + else { + return false; + }; + + assert!( + ret_write_idx < flag_write_idx, + "rewritten return path must assign __ret_val before setting __has_returned" + ); + + true +} + +fn assert_callable_assign_order( + package: &Package, + callable_name: &str, + ret_val_var_id: LocalVarId, + has_returned_var_id: LocalVarId, +) { + let decl = find_callable_decl(package, callable_name); + let mut writes = Vec::new(); + for_each_expr_in_callable_impl(package, &decl.implementation, &mut |expr_id, _expr| { + if let Some(target_local) = assignment_target_local(package, expr_id) { + writes.push(target_local); + } + }); + + let ret_write_idx = writes + .iter() + .position(|local| *local == ret_val_var_id) + .expect("rewritten return path should assign __ret_val"); + let flag_write_idx = writes + .iter() + .position(|local| *local == has_returned_var_id) + .expect("rewritten return path should assign __has_returned"); + + assert!( + ret_write_idx < flag_write_idx, + "rewritten return path must assign __ret_val before setting __has_returned" + ); +} + +fn expr_calls_named_callable( + store: &PackageStore, + package: &Package, + expr_id: ExprId, + callable_name: &str, +) -> bool { + let ExprKind::Call(callee_expr_id, _) = &package.get_expr(expr_id).kind else { + return false; + }; + let ExprKind::Var(Res::Item(item_id), _) = &package.get_expr(*callee_expr_id).kind else { + return false; + }; + + let callee_package = store.get(item_id.package); + matches!( + 
&callee_package.get_item(item_id.item).kind, + ItemKind::Callable(decl) if decl.name.name.as_ref() == callable_name + ) +} + +fn stmt_calls_named_callable( + store: &PackageStore, + package: &Package, + stmt_id: StmtId, + callable_name: &str, +) -> bool { + let expr_id = match &package.get_stmt(stmt_id).kind { + StmtKind::Expr(expr_id) | StmtKind::Semi(expr_id) => *expr_id, + StmtKind::Local(_, _, _) | StmtKind::Item(_) => return false, + }; + + expr_calls_named_callable(store, package, expr_id, callable_name) +} + +fn expr_tree_calls_named_callable( + store: &PackageStore, + package: &Package, + expr_id: ExprId, + callable_name: &str, +) -> bool { + let mut found = false; + for_each_expr(package, expr_id, &mut |nested_expr_id, _expr| { + found |= expr_calls_named_callable(store, package, nested_expr_id, callable_name); + }); + found +} + +fn stmt_tree_calls_named_callable( + store: &PackageStore, + package: &Package, + stmt_id: StmtId, + callable_name: &str, +) -> bool { + let expr_id = match &package.get_stmt(stmt_id).kind { + StmtKind::Expr(expr_id) | StmtKind::Semi(expr_id) | StmtKind::Local(_, _, expr_id) => { + *expr_id + } + StmtKind::Item(_) => return false, + }; + + expr_tree_calls_named_callable(store, package, expr_id, callable_name) +} + +/// Short description of an expression for snapshot output. 
+fn describe_expr(package: &qsc_fir::fir::Package, expr_id: qsc_fir::fir::ExprId) -> String { + let expr = package.get_expr(expr_id); + match &expr.kind { + ExprKind::If(cond, then_e, else_opt) => { + let else_str = match else_opt { + Some(e) => format!(", else={}", describe_expr(package, *e)), + None => String::new(), + }; + format!( + "If(cond={}, then={}{})", + describe_expr(package, *cond), + describe_expr(package, *then_e), + else_str + ) + } + ExprKind::Block(_) => format!("Block[ty={}]", expr.ty), + ExprKind::Lit(lit) => format!("Lit({lit})"), + ExprKind::Var(_, _) => format!("Var[ty={}]", expr.ty), + ExprKind::Call(_, _) => format!("Call[ty={}]", expr.ty), + ExprKind::Tuple(es) => format!("Tuple(len={})", es.len()), + ExprKind::Assign(_, _) => "Assign".to_string(), + ExprKind::While(_, _) => format!("While[ty={}]", expr.ty), + ExprKind::BinOp(op, _, _) => format!("BinOp({op:?})[ty={}]", expr.ty), + ExprKind::UnOp(op, _) => format!("UnOp({op:?})[ty={}]", expr.ty), + _ => crate::test_utils::expr_kind_short(package, expr_id).clone(), + } +} + +fn try_eval_fir_entry( + store: &qsc_fir::fir::PackageStore, + pkg_id: qsc_fir::fir::PackageId, +) -> Result { + use qsc_eval::backend::{SparseSim, TracingBackend}; + use qsc_eval::output::GenericReceiver; + use qsc_fir::fir::ExecGraphConfig; + + let package = store.get(pkg_id); + let entry_graph = package.entry_exec_graph.clone(); + let mut env = qsc_eval::Env::default(); + let mut sim = SparseSim::new(); + let mut out = Vec::::new(); + let mut receiver = GenericReceiver::new(&mut out); + qsc_eval::eval( + pkg_id, + Some(42), + entry_graph, + ExecGraphConfig::NoDebug, + store, + &mut env, + &mut TracingBackend::no_tracer(&mut sim), + &mut receiver, + ) + .map_err(|(err, _frames)| format!("{err:?}")) +} + +/// Compiles Q# source to FIR using a single lowerer (matching the +/// `qsc_eval` test pattern), and evaluates the entry exec graph. 
+/// +/// The FIR has no transforms applied — this captures the original program +/// semantics. +fn eval_qsharp_original(source: &str) -> Result { + use qsc_frontend::compile as frontend_compile; + use qsc_hir::hir::PackageId; + use qsc_lowerer::map_hir_package_to_fir; + use qsc_passes::{PackageType, run_core_passes, run_default_passes}; + + let mut lowerer = qsc_lowerer::Lowerer::new(); + let mut core = frontend_compile::core(); + run_core_passes(&mut core); + let fir_store = qsc_fir::fir::PackageStore::new(); + let core_fir = lowerer.lower_package(&core.package, &fir_store); + let mut hir_store = qsc_frontend::compile::PackageStore::new(core); + + let mut std = frontend_compile::std(&hir_store, TargetCapabilityFlags::empty()); + assert!(std.errors.is_empty()); + assert!(run_default_passes(hir_store.core(), &mut std, PackageType::Lib).is_empty()); + let std_fir = lowerer.lower_package(&std.package, &fir_store); + let std_id = hir_store.insert(std); + + let sources = SourceMap::new(vec![("test.qs".into(), source.into())], None); + let mut unit = frontend_compile::compile( + &hir_store, + &[(PackageId::CORE, None), (std_id, None)], + sources, + TargetCapabilityFlags::empty(), + LanguageFeatures::default(), + ); + assert!(unit.errors.is_empty(), "{:?}", unit.errors); + let pass_errors = run_default_passes(hir_store.core(), &mut unit, PackageType::Exe); + assert!(pass_errors.is_empty(), "{pass_errors:?}"); + let unit_fir = lowerer.lower_package(&unit.package, &fir_store); + let user_hir_id = hir_store.insert(unit); + + let mut fir_store = qsc_fir::fir::PackageStore::new(); + fir_store.insert(map_hir_package_to_fir(PackageId::CORE), core_fir); + fir_store.insert(map_hir_package_to_fir(std_id), std_fir); + fir_store.insert(map_hir_package_to_fir(user_hir_id), unit_fir); + + try_eval_fir_entry(&fir_store, map_hir_package_to_fir(user_hir_id)) +} + +/// Compiles Q# source, runs the full FIR transform pipeline (including +/// `return_unify` and `exec_graph_rebuild`), and 
evaluates the entry exec +/// graph. +fn eval_qsharp_transformed(source: &str) -> Result { + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Full); + try_eval_fir_entry(&store, pkg_id) +} + +/// Asserts semantic equivalence of a Q# program before and after the +/// full FIR transform pipeline. +/// +/// 1. Compiles the original Q# source (no transforms) and evaluates it to +/// get the expected return value. +/// 2. Compiles and runs the full FIR pipeline (including `return_unify`), +/// then evaluates to get the actual return value. +/// 3. Asserts the two results match (both succeed with equal values, or +/// both fail). +fn check_semantic_equivalence(source: &str) { + let expected = eval_qsharp_original(source); + let actual = eval_qsharp_transformed(source); + + match (&expected, &actual) { + (Ok(exp_val), Ok(act_val)) => { + assert_eq!( + exp_val, act_val, + "semantic equivalence violated: original returned {exp_val}, \ + transformed returned {act_val}" + ); + } + (Err(exp_err), Err(act_err)) => { + assert_eq!( + exp_err, act_err, + "semantic equivalence violated: original failed with {exp_err}, transformed failed with {act_err}" + ); + } + (Ok(exp_val), Err(err)) => { + panic!("original succeeded with {exp_val} but transformed failed: {err}"); + } + (Err(err), Ok(act_val)) => { + panic!("original failed with {err} but transformed succeeded with {act_val}"); + } + } +} + +fn check_idempotency(source: &str) { + let (mut store, pkg_id) = compile_return_unified(source); + + // Snapshot arena sizes before the second run. + let before = format!("{:?}", Assigner::from_package(store.get(pkg_id))); + + // Run unify_returns a second time. + let mut assigner = Assigner::from_package(store.get(pkg_id)); + let errors = super::unify_returns(&mut store, pkg_id, &mut assigner); + assert!( + errors.is_empty(), + "second unify_returns pass produced errors: {errors:?}" + ); + + // Snapshot arena sizes after the second run — should be identical. 
+ let after = format!("{:?}", Assigner::from_package(store.get(pkg_id))); + assert_eq!( + before, after, + "second unify_returns pass allocated new nodes (not idempotent)" + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/return_unify/tests/contracts_and_errors.rs b/source/compiler/qsc_fir_transforms/src/return_unify/tests/contracts_and_errors.rs new file mode 100644 index 0000000000..fb9a0bb91a --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/return_unify/tests/contracts_and_errors.rs @@ -0,0 +1,395 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use qsc_data_structures::span::Span; + +use crate::fir_builder::alloc_expr_stmt; + +use super::*; + +#[test] +#[should_panic(expected = "Unit-typed inner stmt")] +fn guard_stmt_with_flag_rejects_non_unit_expr_stmt() { + let source = indoc! {r#" + namespace Test { + @EntryPoint() + operation Main() : Unit {} + } + "#}; + + let (mut store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Mono); + let mut assigner = Assigner::from_package(store.get(pkg_id)); + let package = store.get_mut(pkg_id); + + let lit_expr_id = assigner.next_expr(); + package.exprs.insert( + lit_expr_id, + Expr { + id: lit_expr_id, + span: qsc_data_structures::span::Span::default(), + ty: Ty::Prim(Prim::Int), + kind: ExprKind::Lit(Lit::Int(0)), + exec_graph_range: crate::EMPTY_EXEC_RANGE, + }, + ); + + let stmt_id = { + let assigner: &mut Assigner = &mut assigner; + alloc_expr_stmt(package, assigner, lit_expr_id, Span::default()) + }; + let reachable = FxHashSet::default(); + let udt_pure_tys = super::super::build_scoped_udt_pure_ty_cache(&store, &reachable); + let package = store.get_mut(pkg_id); + let _ = super::super::guard_stmt_with_flag( + package, + &mut assigner, + pkg_id, + stmt_id, + LocalVarId(0), + &udt_pure_tys, + ); +} + +#[test] +fn flag_trailing_without_trailing_expr_uses_return_slot_fallback() { + let source = indoc! 
{r#" + namespace Test { + @EntryPoint() + operation Main() : Unit {} + } + "#}; + + let (mut store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Mono); + let mut assigner = Assigner::from_package(store.get(pkg_id)); + let package = store.get_mut(pkg_id); + + let mut stmts = Vec::new(); + let stmt_id = super::super::create_flag_trailing_expr( + package, + &mut assigner, + &mut stmts, + LocalVarId(0), + LocalVarId(1), + &Ty::Prim(Prim::Int), + ) + .expect("trailing merge statement should be created"); + + let StmtKind::Expr(if_expr_id) = package.get_stmt(stmt_id).kind else { + panic!("expected trailing merge expression statement"); + }; + assert_eq!(package.get_expr(if_expr_id).ty, Ty::Prim(Prim::Int)); +} + +#[test] +fn unsupported_return_slot_default_in_flag_strategy_produces_error() { + let source = indoc! {r#" + namespace Test { + operation Foo(q : Qubit) : Qubit { + mutable i = 0; + while i < 1 { + return q; + } + q + } + + operation Main() : Unit { + use q = Qubit(); + let _ = Foo(q); + Reset(q); + } + } + "#}; + + let (_store, _pkg_id, errors) = + compile_and_run_pipeline_to_with_errors(source, PipelineStage::ReturnUnify); + + assert!( + !errors.is_empty(), + "expected an UnsupportedLoopReturnType error for Qubit return in while" + ); + assert!( + errors.iter().any(|e| e.to_string().contains("Qubit")), + "error should mention the Qubit type, got: {:?}", + errors.iter().map(ToString::to_string).collect::>() + ); +} + +#[test] +#[should_panic(expected = "flag-strategy guarded Local initializer requires a classical default")] +fn unsupported_guarded_local_default_in_flag_strategy_is_explicit_contract() { + let source = indoc! {r#" + namespace Test { + operation Main() : Int { + mutable i = 0; + while i < 1 { + return 1; + } + use q = Qubit(); + 0 + } + } + "#}; + + let _ = compile_and_run_pipeline_to(source, PipelineStage::ReturnUnify); +} + +#[test] +fn qubit_return_in_while_produces_error() { + let source = indoc! 
{r#" + namespace Test { + operation Main() : Qubit { + use q = Qubit(); + mutable i = 0; + while i < 5 { + if i == 3 { + return q; + } + i += 1; + } + q + } + } + "#}; + + let (_store, _pkg_id, errors) = + compile_and_run_pipeline_to_with_errors(source, PipelineStage::ReturnUnify); + + assert!( + !errors.is_empty(), + "expected an UnsupportedLoopReturnType error for Qubit return in while" + ); + assert!( + errors.iter().any(|e| e.to_string().contains("Qubit")), + "error should mention the Qubit type, got: {:?}", + errors.iter().map(ToString::to_string).collect::>() + ); +} + +#[test] +#[allow(clippy::too_many_lines)] +fn test_reachable_only_transformation() { + // Arrange: Create a package with one reachable callable (called from Main) + // with a return statement, and one unreachable callable (never called) with + // a return statement. The reachable callable should be normalized; the + // unreachable one should remain unchanged. + let source = indoc! {r#" + namespace Test { + // Reachable callable that needs return normalization + function Process(x : Int) : Int { + if x > 0 { + return x * 2; + } + x + 1 + } + + // Unreachable callable (never called) - should not be transformed + function UnusedHelper(x : Int) : Int { + if x > 0 { + return x * 3; + } + x + 2 + } + + // Entry point - only calls Process, not UnusedHelper + @EntryPoint() + function Main() : Int { + Process(5) + } + } + "#}; + + // Act: Compile through FIR to capture before state, then run full pipeline + let (before_store, before_pkg_id) = compile_to_fir(source); + let before_package = before_store.get(before_pkg_id); + + // Verify UnusedHelper has returns before transformation + let mut before_unused_has_return = false; + { + let unused_item = before_package + .items + .values() + .find(|item| { + matches!( + &item.kind, + ItemKind::Callable(decl) if decl.name.name.as_ref() == "UnusedHelper" + ) + }) + .expect("UnusedHelper should exist"); + + if let ItemKind::Callable(decl) = &unused_item.kind { + 
for_each_expr_in_callable_impl( + before_package, + &decl.implementation, + &mut |_id, expr| { + before_unused_has_return |= matches!(expr.kind, ExprKind::Return(_)); + }, + ); + } + } + assert!( + before_unused_has_return, + "UnusedHelper should have Return nodes before transformation" + ); + + // Now run return_unify through the full pipeline + let (after_store, after_pkg_id) = compile_return_unified(source); + let after_package = after_store.get(after_pkg_id); + let after_reachable = collect_reachable_from_entry(&after_store, after_pkg_id); + + // Assert: Verify reachable callable (Process) has no returns after transformation + let mut process_has_return = false; + { + let process_item = after_package + .items + .values() + .find(|item| { + matches!( + &item.kind, + ItemKind::Callable(decl) if decl.name.name.as_ref() == "Process" + ) + }) + .expect("Process should exist"); + + if let ItemKind::Callable(decl) = &process_item.kind { + for_each_expr_in_callable_impl( + after_package, + &decl.implementation, + &mut |_id, expr| { + process_has_return |= matches!(expr.kind, ExprKind::Return(_)); + }, + ); + } + } + assert!( + !process_has_return, + "Reachable Process callable should have no Return nodes after return_unify (reachable-only contract)" + ); + + // Assert: Verify unreachable callable (UnusedHelper) was NOT transformed + // and still has returns (documenting the reachable-only semantics) + let mut unused_has_return = false; + { + let unused_item = after_package + .items + .values() + .find(|item| { + matches!( + &item.kind, + ItemKind::Callable(decl) if decl.name.name.as_ref() == "UnusedHelper" + ) + }) + .expect("UnusedHelper should exist"); + + if let ItemKind::Callable(decl) = &unused_item.kind { + for_each_expr_in_callable_impl( + after_package, + &decl.implementation, + &mut |_id, expr| { + unused_has_return |= matches!(expr.kind, ExprKind::Return(_)); + }, + ); + } + } + assert!( + unused_has_return, + "Unreachable UnusedHelper callable should retain 
Return nodes after return_unify (reachable-only contract)\n\ + INVARIANT: Later passes must not resurrect dead callables after return_unify scopes its transformation to reachable code" + ); + + // Verify it's not in the reachable set + let is_unused_reachable = after_reachable.iter().any(|store_id| { + if store_id.package != after_pkg_id { + return false; + } + let item = after_package.get_item(store_id.item); + matches!( + &item.kind, + ItemKind::Callable(decl) if decl.name.name.as_ref() == "UnusedHelper" + ) + }); + assert!( + !is_unused_reachable, + "UnusedHelper must not be in the reachable set" + ); +} + +/// Verify that dead code with return statements is not transformed by `return_unify`, +/// even if it would benefit from normalization. This documents that `return_unify` +/// strictly scopes its transformation to reachable code. + +#[test] +fn test_unreachable_callables_untouched() { + // Arrange: Create a package where a dead callable has return statements + // that would normally need transformation if it were reachable. + let source = indoc! 
{r#" + namespace Test { + // Dead callable with multiple returns that would trigger flag-based + // transformation if it were reachable (due to nested control flow) + function DeadCode(x : Int) : Int { + mutable i = 0; + while i < 5 { + if x == i { + return i; + } + i += 1; + } + -1 + } + + // Entry point that never calls DeadCode + @EntryPoint() + function Main() : Int { + 42 + } + } + "#}; + + // Act: Compile and run return_unify + let (store, pkg_id) = compile_return_unified(source); + let package = store.get(pkg_id); + let reachable = collect_reachable_from_entry(&store, pkg_id); + + // Assert: Verify DeadCode is not in reachable set + let is_deadcode_reachable = reachable.iter().any(|store_id| { + if store_id.package != pkg_id { + return false; + } + let item = package.get_item(store_id.item); + matches!( + &item.kind, + ItemKind::Callable(decl) if decl.name.name.as_ref() == "DeadCode" + ) + }); + assert!( + !is_deadcode_reachable, + "DeadCode should not be in reachable set" + ); + + // Assert: Verify DeadCode still has Return nodes (was not transformed) + // This is the core contract: unreachable code is left untouched + let mut deadcode_has_return = false; + { + let deadcode_item = package + .items + .values() + .find(|item| { + matches!( + &item.kind, + ItemKind::Callable(decl) if decl.name.name.as_ref() == "DeadCode" + ) + }) + .expect("DeadCode should exist"); + + if let ItemKind::Callable(decl) = &deadcode_item.kind { + for_each_expr_in_callable_impl(package, &decl.implementation, &mut |_id, expr| { + deadcode_has_return |= matches!(expr.kind, ExprKind::Return(_)); + }); + } + } + assert!( + deadcode_has_return, + "Unreachable DeadCode should retain Return nodes (reachable-only transformation contract)\n\ + CRITICAL INVARIANT: return_unify must not transform unreachable code, as later passes assume \ + only reachable code is normalized. Any resurrections of dead code would violate the no-return invariant." 
+ ); +} diff --git a/source/compiler/qsc_fir_transforms/src/return_unify/tests/flag_strategy.rs b/source/compiler/qsc_fir_transforms/src/return_unify/tests/flag_strategy.rs new file mode 100644 index 0000000000..fe372cdc76 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/return_unify/tests/flag_strategy.rs @@ -0,0 +1,1352 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use super::*; +use crate::fir_builder::functored_specs; + +#[test] +fn return_inside_while_loop() { + // Flag-based transformation with `__has_returned` and `__ret_val`. + check_no_returns_q( + indoc! {r#" + namespace Test { + function Main() : Int { + mutable i = 0; + while i < 10 { + if i == 5 { + return i; + } + i += 1; + } + -1 + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + mutable i : Int = 0; + while not __has_returned and i < 10 { + if i == 5 { + { + __ret_val = i; + __has_returned = true; + }; + } + + if not __has_returned { + i += 1; + }; + } + + let __trailing_result : Int = -1; + if __has_returned __ret_val else __trailing_result + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn while_return_tuple_value_uses_flag_fallback() { + let source = indoc! 
{r#" + namespace Test { + function Main() : (Int, Bool) { + mutable i = 0; + while i < 3 { + if i == 1 { + return (i, true); + } + i += 1; + } + (-1, false) + } + } + "#}; + + check_no_returns_q( + source, + &expect![[r#" + // namespace Test + function Main() : (Int, Bool) { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : (Int, Bool) = (0, false); + mutable i : Int = 0; + while not __has_returned and i < 3 { + if i == 1 { + { + __ret_val = (i, true); + __has_returned = true; + }; + } + + if not __has_returned { + i += 1; + }; + } + + let __trailing_result : (Int, Bool) = (-1, false); + if __has_returned __ret_val else __trailing_result + } + } + // entry + Main() + "#]], + ); + + let (store, pkg_id) = compile_return_unified(source); + let package = store.get(pkg_id); + let (pat, init_expr) = find_local_init(package, "Main", "__ret_val"); + + assert_eq!( + pat.ty, + Ty::Tuple(vec![Ty::Prim(Prim::Int), Ty::Prim(Prim::Bool)]) + ); + + let ExprKind::Tuple(items) = &init_expr.kind else { + panic!( + "expected tuple fallback initializer, got {:?}", + init_expr.kind + ); + }; + assert_eq!(items.len(), 2, "tuple fallback should preserve arity"); + assert_eq!(package.get_expr(items[0]).ty, Ty::Prim(Prim::Int)); + assert_eq!(package.get_expr(items[1]).ty, Ty::Prim(Prim::Bool)); +} + +#[test] +fn all_returning_nested_if_tuple_uses_return_slot_fallback() { + let source = indoc! 
{r#" + namespace Test { + function Touch() : Unit { () } + + function Main() : (Bool, (Int, Int)) { + let value = 3; + if value > 0 { + if value > 1 { + if value > 2 { + Touch(); + return (true, (value, value)); + } + } + Touch(); + return (false, (1, 1)); + } else { + Touch(); + return (false, (2, 2)); + } + } + } + "#}; + + let (store, pkg_id) = compile_return_unified(source); + let package = store.get(pkg_id); + let (ret_val_pat, _) = find_local_init(package, "Main", "__ret_val"); + let ret_val_var_id = local_var_id_from_named_pat(ret_val_pat, "__ret_val"); + + let body_block_id = find_body_block_id(package, "Main"); + let body_block = package.get_block(body_block_id); + let trailing_stmt_id = *body_block + .stmts + .last() + .expect("expected rewritten Main body to have a trailing expression"); + let StmtKind::Expr(trailing_expr_id) = &package.get_stmt(trailing_stmt_id).kind else { + panic!("expected rewritten Main body to end with trailing Expr") + }; + assert!( + expr_reads_local(package, *trailing_expr_id, ret_val_var_id), + "all-returning non-Unit block should use __ret_val as its final expression" + ); + + let has_trailing_result = body_block.stmts.iter().any(|stmt_id| { + let StmtKind::Local(_, pat_id, _) = package.get_stmt(*stmt_id).kind else { + return false; + }; + let pat = package.get_pat(pat_id); + matches!(&pat.kind, PatKind::Bind(ident) if ident.name.as_ref() == "__trailing_result") + }); + assert!( + !has_trailing_result, + "Unit trailing statements in all-returning non-Unit blocks must not be captured as __trailing_result" + ); +} + +#[test] +fn while_return_array_value_uses_flag_fallback() { + let source = indoc! 
{r#" + namespace Test { + function Main() : Int[] { + mutable i = 0; + while i < 3 { + if i == 1 { + return [i, i + 1]; + } + i += 1; + } + [] + } + } + "#}; + + check_no_returns_q( + source, + &expect![[r#" + // namespace Test + function Main() : Int[] { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int[] = []; + mutable i : Int = 0; + while not __has_returned and i < 3 { + if i == 1 { + { + __ret_val = [i, i + 1]; + __has_returned = true; + }; + } + + if not __has_returned { + i += 1; + }; + } + + let __trailing_result : Int[] = []; + if __has_returned __ret_val else __trailing_result + } + } + // entry + Main() + "#]], + ); + + let (store, pkg_id) = compile_return_unified(source); + let package = store.get(pkg_id); + let (pat, init_expr) = find_local_init(package, "Main", "__ret_val"); + + assert_eq!(pat.ty, Ty::Array(Box::new(Ty::Prim(Prim::Int)))); + + let ExprKind::Array(items) = &init_expr.kind else { + panic!( + "expected array fallback initializer, got {:?}", + init_expr.kind + ); + }; + assert!( + items.is_empty(), + "array fallback should start from an empty array" + ); +} + +#[test] +fn while_local_initializer_if_return_is_rewritten_by_flag_strategy() { + let source = indoc! 
{r#" + namespace Test { + function Add(a : Int, b : Int) : Int { a + b } + function Main() : Int { + mutable i = 0; + while i < 3 { + let _ = if i == 1 { + Add((return 42), i) + }; + i += 1; + } + i + 5 + } + } + "#}; + + let (store, pkg_id) = compile_return_unified(source); + let package = store.get(pkg_id); + + let (has_returned_pat, _) = find_local_init(package, "Main", "__has_returned"); + let has_returned_var_id = local_var_id_from_named_pat(has_returned_pat, "__has_returned"); + let (ret_val_pat, _) = find_local_init(package, "Main", "__ret_val"); + let ret_val_var_id = local_var_id_from_named_pat(ret_val_pat, "__ret_val"); + + let body_block_id = find_body_block_id(package, "Main"); + let body_block = package.get_block(body_block_id); + + let (while_cond_id, while_body_block_id) = body_block + .stmts + .iter() + .find_map(|&stmt_id| { + let while_expr_id = match &package.get_stmt(stmt_id).kind { + StmtKind::Expr(expr_id) | StmtKind::Semi(expr_id) => *expr_id, + StmtKind::Local(_, _, _) | StmtKind::Item(_) => return None, + }; + let ExprKind::While(cond_id, body_id) = &package.get_expr(while_expr_id).kind else { + return None; + }; + Some((*cond_id, *body_id)) + }) + .expect("expected Main body to contain a rewritten while loop"); + + assert_while_condition_guarded_by_not_flag(package, while_cond_id, has_returned_var_id); + + let while_block = package.get_block(while_body_block_id); + let local_init_expr_id = while_block + .stmts + .iter() + .find_map(|&stmt_id| match &package.get_stmt(stmt_id).kind { + StmtKind::Local(_, _, init_expr_id) => Some(*init_expr_id), + StmtKind::Expr(_) | StmtKind::Semi(_) | StmtKind::Item(_) => None, + }) + .expect("expected while body to keep a Local initializer statement"); + + let local_order_pinned = assert_local_initializer_then_assign_order( + package, + local_init_expr_id, + ret_val_var_id, + has_returned_var_id, + ); + if !local_order_pinned { + assert_callable_assign_order(package, "Main", ret_val_var_id, 
has_returned_var_id); + } + + let trailing_stmt_id = *body_block + .stmts + .last() + .expect("expected rewritten Main body to have a trailing expression"); + let StmtKind::Expr(trailing_expr_id) = &package.get_stmt(trailing_stmt_id).kind else { + panic!("expected rewritten Main body to end with trailing Expr") + }; + let ExprKind::If(flag_expr_id, then_expr_id, Some(else_expr_id)) = + &package.get_expr(*trailing_expr_id).kind + else { + panic!("expected trailing merge expression to be if __has_returned ...") + }; + + assert!( + expr_reads_local(package, *flag_expr_id, has_returned_var_id), + "trailing merge condition should read __has_returned" + ); + assert!( + expr_reads_local(package, *then_expr_id, ret_val_var_id), + "trailing merge then-branch should read __ret_val" + ); + + // After the bind-then-check fix, the else branch reads __trailing_result (a Var) + // rather than the original fallthrough expression directly. + assert!( + matches!( + &package.get_expr(*else_expr_id).kind, + ExprKind::Var(Res::Local(_), _) + ), + "trailing merge else-branch should read __trailing_result" + ); +} + +#[allow(clippy::too_many_lines)] +#[test] +fn while_local_initializer_if_else_return_preserves_fallthrough_tail() { + let source = indoc! 
{r#" + namespace Test { + function Add(a : Int, b : Int) : Int { a + b } + function Main() : Int { + mutable i = 0; + while i < 3 { + let x = if i == 1 { + Add((return 7), i) + } else { + i + 10 + }; + i += x; + } + let tail = i + 5; + tail + } + } + "#}; + + let (store, pkg_id) = compile_return_unified(source); + let package = store.get(pkg_id); + + let (has_returned_pat, _) = find_local_init(package, "Main", "__has_returned"); + let has_returned_var_id = local_var_id_from_named_pat(has_returned_pat, "__has_returned"); + let (ret_val_pat, _) = find_local_init(package, "Main", "__ret_val"); + let ret_val_var_id = local_var_id_from_named_pat(ret_val_pat, "__ret_val"); + + let body_block_id = find_body_block_id(package, "Main"); + let body_block = package.get_block(body_block_id); + + let (while_cond_id, while_body_block_id) = body_block + .stmts + .iter() + .find_map(|&stmt_id| { + let while_expr_id = match &package.get_stmt(stmt_id).kind { + StmtKind::Expr(expr_id) | StmtKind::Semi(expr_id) => *expr_id, + StmtKind::Local(_, _, _) | StmtKind::Item(_) => return None, + }; + let ExprKind::While(cond_id, body_id) = &package.get_expr(while_expr_id).kind else { + return None; + }; + Some((*cond_id, *body_id)) + }) + .expect("expected Main body to contain a rewritten while loop"); + + assert_while_condition_guarded_by_not_flag(package, while_cond_id, has_returned_var_id); + + let while_block = package.get_block(while_body_block_id); + let x_local_init_expr_id = while_block + .stmts + .iter() + .find_map(|&stmt_id| { + let StmtKind::Local(_, pat_id, init_expr_id) = &package.get_stmt(stmt_id).kind else { + return None; + }; + let pat = package.get_pat(*pat_id); + let PatKind::Bind(ident) = &pat.kind else { + return None; + }; + (ident.name.as_ref() == "x").then_some(*init_expr_id) + }) + .expect("expected while body to contain Local x initializer"); + + let local_order_pinned = assert_local_initializer_then_assign_order( + package, + x_local_init_expr_id, + ret_val_var_id, 
+ has_returned_var_id, + ); + if !local_order_pinned { + assert_callable_assign_order(package, "Main", ret_val_var_id, has_returned_var_id); + } + + let (_tail_var_id, tail_init_expr_id) = body_block + .stmts + .iter() + .find_map(|&stmt_id| { + let StmtKind::Local(_, pat_id, init_expr_id) = &package.get_stmt(stmt_id).kind else { + return None; + }; + let pat = package.get_pat(*pat_id); + let PatKind::Bind(ident) = &pat.kind else { + return None; + }; + (ident.name.as_ref() == "tail").then_some((ident.id, *init_expr_id)) + }) + .expect("expected Main body to contain guarded tail local"); + + let ExprKind::If(guard_cond_id, _then_expr_id, Some(else_expr_id)) = + &package.get_expr(tail_init_expr_id).kind + else { + panic!("tail initializer should be guarded by if not __has_returned") + }; + assert!( + is_not_flag_expr(package, *guard_cond_id, has_returned_var_id), + "tail initializer guard should be not __has_returned" + ); + + let guard_else_kind = &package.get_expr(*else_expr_id).kind; + let guard_else_is_int_zero = if matches!(guard_else_kind, ExprKind::Lit(Lit::Int(0))) { + true + } else if let ExprKind::Block(block_id) = guard_else_kind { + let block = package.get_block(*block_id); + match block.stmts.last() { + Some(last_stmt_id) => matches!( + &package.get_stmt(*last_stmt_id).kind, + StmtKind::Expr(expr_id) + if matches!(&package.get_expr(*expr_id).kind, ExprKind::Lit(Lit::Int(0))) + ), + None => false, + } + } else { + false + }; + + assert!( + guard_else_is_int_zero, + "guarded Int local fallback should synthesize 0 in else-branch" + ); + + let trailing_stmt_id = *body_block + .stmts + .last() + .expect("expected rewritten Main body to have a trailing expression"); + let StmtKind::Expr(trailing_expr_id) = &package.get_stmt(trailing_stmt_id).kind else { + panic!("expected rewritten Main body to end with trailing Expr") + }; + let ExprKind::If(flag_expr_id, then_expr_id, Some(else_expr_id)) = + &package.get_expr(*trailing_expr_id).kind + else { + 
panic!("expected trailing merge expression to be if __has_returned ...") + }; + + assert!( + expr_reads_local(package, *flag_expr_id, has_returned_var_id), + "trailing merge condition should read __has_returned" + ); + assert!( + expr_reads_local(package, *then_expr_id, ret_val_var_id), + "trailing merge then-branch should read __ret_val" + ); + + // After the bind-then-check fix, the else branch reads __trailing_result rather than + // the `tail` local directly. + let (trailing_result_pat, _) = find_local_init(package, "Main", "__trailing_result"); + let trailing_result_var_id = + local_var_id_from_named_pat(trailing_result_pat, "__trailing_result"); + assert!( + expr_reads_local(package, *else_expr_id, trailing_result_var_id), + "trailing merge else-branch should read __trailing_result" + ); +} + +#[test] +fn nested_loop_exit_convergence_is_guarded_by_flag() { + let source = indoc! {r#" + namespace Test { + function Main() : Int { + mutable outer = 0; + mutable inner = 0; + while outer < 2 { + while inner < 2 { + if inner == 1 { + return outer + inner; + } + inner += 1; + } + outer += 1; + inner = 0; + } + -1 + } + } + "#}; + + let (store, pkg_id) = compile_return_unified(source); + let rendered = crate::pretty::write_package_qsharp(&store, pkg_id); + + assert!( + rendered.contains("while not __has_returned and outer < 2"), + "outer loop exit convergence must be guarded by __has_returned", + ); + assert!( + rendered.contains("while not __has_returned and inner < 2"), + "inner loop exit convergence must be guarded by __has_returned", + ); + assert!( + !rendered.contains("while inner < 2 {"), + "inner loop should not remain unguarded after return unification", + ); +} + +#[test] +fn lowered_reachable_callables_do_not_emit_while_local_initializers() { + let source = indoc! 
{r#" + namespace Test { + function Helper(flag : Bool) : Int { + mutable i = 0; + while i < 3 { + let x = if flag { + i + } else { + i + 1 + }; + i += x; + } + i + } + + @EntryPoint() + function Main() : Int { + let seed = 1; + seed + Helper(true) + } + } + "#}; + + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Mono); + let package = store.get(pkg_id); + let reachable = collect_reachable_from_entry(&store, pkg_id); + let mut offenders = Vec::new(); + + for store_id in &reachable { + if store_id.package != pkg_id { + continue; + } + + let item = package.get_item(store_id.item); + let ItemKind::Callable(decl) = &item.kind else { + continue; + }; + + let mut block_ids = Vec::new(); + match &decl.implementation { + CallableImpl::Spec(spec_impl) => { + block_ids.push(spec_impl.body.block); + for spec in functored_specs(spec_impl) { + block_ids.push(spec.block); + } + } + CallableImpl::SimulatableIntrinsic(spec) => { + block_ids.push(spec.block); + } + CallableImpl::Intrinsic => {} + } + + for_each_expr_in_callable_impl(package, &decl.implementation, &mut |_expr_id, expr| { + if let ExprKind::Block(block_id) | ExprKind::While(_, block_id) = expr.kind { + block_ids.push(block_id); + } + }); + + block_ids.sort_unstable_by_key(|block_id| block_id.0); + block_ids.dedup(); + + for block_id in block_ids { + let block = package.get_block(block_id); + for &stmt_id in &block.stmts { + let StmtKind::Local(_, pat_id, init_expr_id) = package.get_stmt(stmt_id).kind + else { + continue; + }; + + if !matches!(package.get_expr(init_expr_id).kind, ExprKind::While(_, _)) { + continue; + } + + let pat = package.get_pat(pat_id); + let pat_desc = match &pat.kind { + PatKind::Bind(ident) => ident.name.to_string(), + PatKind::Tuple(_) => "".to_string(), + PatKind::Discard => "_".to_string(), + }; + + offenders.push(format!( + "{}: block {block_id}, stmt {stmt_id}, pat {pat_desc}", + decl.name.name + )); + } + } + } + + assert!( + offenders.is_empty(), + 
"entry-reachable lowered FIR should not contain Local initializers with while expressions; found: {}", + offenders.join("; ") + ); +} + +#[test] +fn synthetic_while_local_initializer_shape_still_eliminates_returns() { + let source = indoc! {r#" + namespace Test { + @EntryPoint() + function Main() : Int { + let marker = (); + mutable i = 0; + while i < 2 { + if i == 1 { + return 9; + } + i += 1; + } + 0 + } + } + "#}; + + let (mut store, pkg_id) = compile_to_fir(source); + + let (marker_stmt_id, while_expr_id) = { + let package = store.get(pkg_id); + let body_block_id = find_body_block_id(package, "Main"); + let body_block = package.get_block(body_block_id); + + let marker_stmt_id = body_block + .stmts + .iter() + .copied() + .find(|stmt_id| { + let StmtKind::Local(_, pat_id, _init_expr_id) = package.get_stmt(*stmt_id).kind + else { + return false; + }; + let pat = package.get_pat(pat_id); + matches!(&pat.kind, PatKind::Bind(ident) if ident.name.as_ref() == "marker") + }) + .expect("expected Main body to contain local 'marker'"); + + let while_expr_id = body_block + .stmts + .iter() + .find_map(|stmt_id| match package.get_stmt(*stmt_id).kind { + StmtKind::Expr(expr_id) | StmtKind::Semi(expr_id) + if matches!(package.get_expr(expr_id).kind, ExprKind::While(_, _)) => + { + Some(expr_id) + } + _ => None, + }) + .expect("expected Main body to contain a while statement expression"); + + (marker_stmt_id, while_expr_id) + }; + + let mut assigner = Assigner::from_package(store.get(pkg_id)); + { + let package = store.get_mut(pkg_id); + let while_expr = package.get_expr(while_expr_id).clone(); + let synthetic_while_expr_id = assigner.next_expr(); + package.exprs.insert( + synthetic_while_expr_id, + Expr { + id: synthetic_while_expr_id, + ..while_expr + }, + ); + + let marker_stmt = package + .stmts + .get_mut(marker_stmt_id) + .expect("marker stmt should exist"); + let StmtKind::Local(mutability, pat_id, _) = marker_stmt.kind else { + panic!("marker stmt should remain a Local 
after lookup") + }; + marker_stmt.kind = StmtKind::Local(mutability, pat_id, synthetic_while_expr_id); + + assert!( + matches!( + package.get_expr(synthetic_while_expr_id).kind, + ExprKind::While(_, _) + ), + "synthetic setup should place a while expression in Local initializer" + ); + } + + let errors = crate::run_pipeline_to(&mut store, pkg_id, PipelineStage::ReturnUnify, &[]); + assert!( + errors.is_empty(), + "return_unify pipeline should complete on synthetic while-local-initializer shape" + ); + + let package = store.get(pkg_id); + let reachable = collect_reachable_from_entry(&store, pkg_id); + for store_id in &reachable { + if store_id.package != pkg_id { + continue; + } + let item = package.get_item(store_id.item); + if let ItemKind::Callable(decl) = &item.kind { + for_each_expr_in_callable_impl(package, &decl.implementation, &mut |_id, expr| { + assert!( + !matches!(expr.kind, ExprKind::Return(_)), + "synthetic while-local-initializer shape should still satisfy PostReturnUnify no-return invariant" + ); + }); + } + } +} + +#[test] +fn while_body_call_arg_return_keeps_loop_before_trailing_merge() { + check_structure( + indoc! {r#" + namespace Test { + function Add(a : Int, b : Int) : Int { a + b } + + function Main() : Int { + mutable i = 0; + while i < 3 { + let _ = Add((return 42), 2); + i += 1; + } + -1 + } + } + "#}, + &["Main"], + &expect![[r#" + callable Main: input_ty=Unit, output_ty=Int + body: block_ty=Int + [0] Local(Mutable, __has_returned: Bool): Lit(Bool(false)) + [1] Local(Mutable, __ret_val: Int): Lit(Int(0)) + [2] Local(Mutable, i: Int): Lit(Int(0)) + [3] Expr While[ty=Unit] + [4] Local(Immutable, __trailing_result: Int): UnOp(Neg)[ty=Int] + [5] Expr If(cond=Var[ty=Bool], then=Var[ty=Int], else=Var[ty=Int])"#]], + ); +} + +#[test] +fn nested_block_with_while_return_not_transformable_by_if_else() { + // For-loop desugaring wraps a While in a Block. 
When transform_block_if_else + // encounters this NestedBlock, the inner block contains a While-with-return + // that it can't handle (While falls to ReturnClass::None). The !changed + // guard must return false to prevent infinite recursion. + // + // This test calls transform_block_if_else directly (bypassing unify_returns + // which would route to the flag-based path) to exercise the guard. + let (mut store, pkg_id) = compile_and_run_pipeline_to( + indoc! {r#" + namespace Test { + function Main() : Int { + for i in 0..10 { + if i == 5 { + return i; + } + } + -1 + } + } + "#}, + PipelineStage::Mono, + ); + + let package = store.get(pkg_id); + let reachable = collect_reachable_from_entry(&store, pkg_id); + + // Find Main's body block and return type. + let (block_id, return_ty) = reachable + .iter() + .filter(|id| id.package == pkg_id) + .find_map(|id| { + let item = package.get_item(id.item); + if let ItemKind::Callable(decl) = &item.kind + && let CallableImpl::Spec(spec) = &decl.implementation + { + return Some((spec.body.block, decl.output.clone())); + } + None + }) + .expect("Main callable not found"); + + let mut assigner = Assigner::from_package(package); + + let package = store.get_mut(pkg_id); + + // transform_block_if_else should return false because the nested block + // contains a while-with-return that requires the flag-based transform. + let changed = + super::super::transform_block_if_else(package, &mut assigner, block_id, &return_ty); + assert!( + !changed, + "transform_block_if_else should return false for nested block containing while-with-return", + ); +} + +#[test] +fn range_return_default_in_flag_strategy_is_supported() { + let source = indoc! 
{r#" + namespace Test { + function Main() : Range { + mutable i = 0; + while i < 1 { + return 0..1; + } + 2..3 + } + } + "#}; + + let (store, pkg_id) = compile_return_unified(source); + let rendered = crate::pretty::write_package_qsharp(&store, pkg_id); + + assert!( + rendered.contains("mutable __ret_val : Range ="), + "flag strategy should synthesize a default Range return slot", + ); + assert!( + rendered.contains("if __has_returned __ret_val else"), + "final trailing expression should select between captured return and fallthrough", + ); +} + +#[test] +fn tuple_return_in_while_with_nested_if() { + check_no_returns_q( + indoc! {r#" + namespace Test { + function Main() : (Int, Bool) { + mutable i = 0; + while i < 10 { + if i > 5 { + if i == 7 { + return (i, true); + } + } + i += 1; + } + (-1, false) + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : (Int, Bool) { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : (Int, Bool) = (0, false); + mutable i : Int = 0; + while not __has_returned and i < 10 { + if i > 5 { + if i == 7 { + { + __ret_val = (i, true); + __has_returned = true; + }; + } + + } + + if not __has_returned { + i += 1; + }; + } + + let __trailing_result : (Int, Bool) = (-1, false); + if __has_returned __ret_val else __trailing_result + } + } + // entry + Main() + "#]], + ); +} + +#[test] +#[allow(clippy::too_many_lines)] +fn all_four_specializations_with_return_in_loop() { + check_no_returns_q( + indoc! {r#" + namespace Test { + operation Op(q : Qubit) : Unit is Adj + Ctl { + body ... { + mutable i = 0; + while i < 5 { + if i == 3 { + return (); + } + i += 1; + } + () + } + adjoint ... { + mutable j = 0; + while j < 5 { + if j == 2 { + return (); + } + j += 1; + } + () + } + controlled (cs, ...) { + mutable k = 0; + while k < 5 { + if k == 4 { + return (); + } + k += 1; + } + () + } + controlled adjoint (cs, ...) 
{ + mutable m = 0; + while m < 5 { + if m == 1 { + return (); + } + m += 1; + } + () + } + } + operation Main() : Unit { + use q = Qubit(); + Op(q) + } + } + "#}, + &expect![[r#" + // namespace Test + operation Op(q : Qubit) : Unit is Adj + Ctl { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Unit = (); + mutable i : Int = 0; + while not __has_returned and i < 5 { + if i == 3 { + { + __ret_val = (); + __has_returned = true; + }; + } + + if not __has_returned { + i += 1; + }; + } + + let __trailing_result : Unit = (); + if __has_returned __ret_val else __trailing_result + } + adjoint { + mutable __has_returned : Bool = false; + mutable __ret_val : Unit = (); + mutable j : Int = 0; + while not __has_returned and j < 5 { + if j == 2 { + { + __ret_val = (); + __has_returned = true; + }; + } + + if not __has_returned { + j += 1; + }; + } + + let __trailing_result : Unit = (); + if __has_returned __ret_val else __trailing_result + } + controlled { + mutable __has_returned : Bool = false; + mutable __ret_val : Unit = (); + mutable k : Int = 0; + while not __has_returned and k < 5 { + if k == 4 { + { + __ret_val = (); + __has_returned = true; + }; + } + + if not __has_returned { + k += 1; + }; + } + + let __trailing_result : Unit = (); + if __has_returned __ret_val else __trailing_result + } + controlled adjoint { + mutable __has_returned : Bool = false; + mutable __ret_val : Unit = (); + mutable m : Int = 0; + while not __has_returned and m < 5 { + if m == 1 { + { + __ret_val = (); + __has_returned = true; + }; + } + + if not __has_returned { + m += 1; + }; + } + + let __trailing_result : Unit = (); + if __has_returned __ret_val else __trailing_result + } + } + operation Main() : Unit { + body { + let q : Qubit = __quantum__rt__qubit_allocate(); + let + @generated_ident_142 : Unit = Op(q); + __quantum__rt__qubit_release(q); + @generated_ident_142 + } + } + // entry + Main() + "#]], + ); +} + +// Qubit alloc scope + flag strategy + +#[test] +fn 
qubit_alloc_scope_with_flag_strategy() { + check_no_returns_q( + indoc! {r#" + namespace Test { + operation Main() : Int { + mutable i = 0; + while i < 5 { + use q = Qubit(); + if i == 3 { + return i; + } + i += 1; + } + -1 + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + mutable i : Int = 0; + while not __has_returned and i < 5 { + let q : Qubit = __quantum__rt__qubit_allocate(); + if i == 3 { + { + let + @generated_ident_45 : Int = i; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_45; + __has_returned = true; + }; + }; + } + + if not __has_returned { + i += 1; + }; + if not __has_returned { + __quantum__rt__qubit_release(q); + }; + } + + let __trailing_result : Int = -1; + if __has_returned __ret_val else __trailing_result + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn repeat_until_with_return() { + check_no_returns_q( + indoc! {r#" + namespace Test { + operation Main() : Int { + mutable result = 0; + mutable attempt = 0; + repeat { + if attempt > 3 { + return -1; + } + attempt += 1; + result = attempt * 2; + } until result > 5; + result + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + mutable result : Int = 0; + mutable attempt : Int = 0; + { + mutable + @continue_cond_46 : Bool = true; + while not __has_returned and + @continue_cond_46 { + if attempt > 3 { + { + __ret_val = -1; + __has_returned = true; + }; + } + + if not __has_returned { + attempt += 1; + }; + if not __has_returned { + result = attempt * 2; + }; + if not __has_returned { + @continue_cond_46 = not result > 5; + }; + } + + }; + let __trailing_result : Int = result; + if __has_returned __ret_val else __trailing_result + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn while_body_side_effect_guarded_after_return() { + 
check_no_returns_q( + indoc! {r#" + namespace Test { + operation Main() : Int { + mutable sum = 0; + mutable i = 0; + while i < 10 { + if i == 3 { + return sum; + } + // These should be guarded so they don't fire after return + sum += i; + i += 1; + } + sum + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + mutable sum : Int = 0; + mutable i : Int = 0; + while not __has_returned and i < 10 { + if i == 3 { + { + __ret_val = sum; + __has_returned = true; + }; + } + + if not __has_returned { + sum += i; + }; + if not __has_returned { + i += 1; + }; + } + + let __trailing_result : Int = sum; + if __has_returned __ret_val else __trailing_result + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn if_expr_init_with_while_return_uses_flag_strategy() { + check_no_returns_q( + indoc! {r#" + namespace Test { + function Main() : Int { + let x = if true { + mutable i = 0; + while i < 5 { + if i == 3 { + return 42; + } + i += 1; + } + 0 + } else { + 1 + }; + x + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + let x : Int = if true { + mutable i : Int = 0; + while not __has_returned and i < 5 { + if i == 3 { + { + __ret_val = 42; + __has_returned = true; + }; + } + + if not __has_returned { + i += 1; + }; + } + + 0 + } else { + 1 + }; + let __trailing_result : Int = x; + if __has_returned __ret_val else __trailing_result + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn flag_strategy_guards_local_after_return() { + // A Local statement following a return-bearing statement must be + // guarded by rewriting the initializer, not wrapping the whole Local. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + function Main() : Int { + mutable i = 0; + while i < 5 { + if i == 3 { + return i; + } + let y = i * 2; + i += 1; + } + -1 + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + mutable i : Int = 0; + while not __has_returned and i < 5 { + if i == 3 { + { + __ret_val = i; + __has_returned = true; + }; + } + + let y : Int = if not __has_returned { + i * 2 + } else { + 0 + }; + if not __has_returned { + i += 1; + }; + } + + let __trailing_result : Int = -1; + if __has_returned __ret_val else __trailing_result + } + } + // entry + Main() + "#]], + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/return_unify/tests/idempotency.rs b/source/compiler/qsc_fir_transforms/src/return_unify/tests/idempotency.rs new file mode 100644 index 0000000000..ee5e992c2f --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/return_unify/tests/idempotency.rs @@ -0,0 +1,169 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use super::*; + +#[test] +fn idempotency_no_return() { + // No returns at all — structured strategy not triggered. + check_idempotency(indoc! {r#" + namespace Test { + function Main() : Int { + 42 + } + } + "#}); +} + +#[test] +fn idempotency_simple_guard_clause() { + // Single guard clause — structured (if-else) strategy. + check_idempotency(indoc! {r#" + namespace Test { + function Main() : Int { + if true { + return 1; + } + 0 + } + } + "#}); +} + +#[test] +fn idempotency_nested_if_else_returns() { + // Multiple branches with returns — structured strategy. + check_idempotency(indoc! {r#" + namespace Test { + function Main() : Int { + if true { + return 1; + } elif false { + return 2; + } else { + return 3; + } + } + } + "#}); +} + +#[test] +fn idempotency_while_loop_return() { + // Return inside while loop — flag strategy. + check_idempotency(indoc! 
{r#" + namespace Test { + function Main() : Int { + mutable i = 0; + while i < 10 { + if i == 5 { + return i; + } + i += 1; + } + i + } + } + "#}); +} + +#[test] +fn idempotency_for_loop_return() { + // Return inside for loop — flag strategy. + check_idempotency(indoc! {r#" + namespace Test { + function Main() : Int { + for i in 0..9 { + if i == 5 { + return i; + } + } + -1 + } + } + "#}); +} + +#[test] +fn idempotency_nested_blocks_with_return() { + // Return inside nested block — tests block normalization idempotency. + check_idempotency(indoc! {r#" + namespace Test { + function Main() : Int { + let x = { + if true { + return 1; + } + 2 + }; + x + } + } + "#}); +} + +#[test] +fn idempotency_unit_return() { + // Unit-typed early return. + check_idempotency(indoc! {r#" + namespace Test { + function Main() : Unit { + if true { + return (); + } + } + } + "#}); +} + +#[test] +fn idempotency_tuple_return() { + // Tuple-typed return — structured strategy. + check_idempotency(indoc! {r#" + namespace Test { + function Main() : (Int, Bool) { + if true { + return (1, true); + } + (0, false) + } + } + "#}); +} + +#[test] +fn idempotency_string_return_flag_strategy() { + // String-typed return in while loop — flag strategy. + check_idempotency(indoc! {r#" + namespace Test { + function Main() : String { + mutable i = 0; + while i < 3 { + if i == 1 { + return "found"; + } + i += 1; + } + "not found" + } + } + "#}); +} + +#[test] +fn idempotency_leaky_if_flag_strategy() { + // Leaky nested-if pattern — flag strategy with non-trivial guarding. + check_idempotency(indoc! 
{r#" + namespace Test { + function Main() : Int { + if true { + if false { + return 1; + } + return 2; + } + 3 + } + } + "#}); +} diff --git a/source/compiler/qsc_fir_transforms/src/return_unify/tests/qubit_release.rs b/source/compiler/qsc_fir_transforms/src/return_unify/tests/qubit_release.rs new file mode 100644 index 0000000000..0d11ca6fd5 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/return_unify/tests/qubit_release.rs @@ -0,0 +1,313 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use super::*; + +#[test] +fn no_release_hoist_path_local_release_all_branches_return_keeps_branch_releases() { + let source = indoc! {r#" + namespace Test { + operation Foo(flag : Bool) : Int { + use q = Qubit(); + if flag { + return 1; + } else { + return 0; + } + } + + @EntryPoint() + operation Main() : Int { + Foo(true) + } + } + "#}; + + let result = compile_no_hoist_return_unified(source); + assert_path_local_releases_without_unconditional_suffix(&result, "Foo"); + check_no_hoist_semantic_equivalence(source); +} + +#[test] +fn no_release_hoist_path_local_release_guard_return_threads_fallthrough_release() { + let source = indoc! {r#" + namespace Test { + operation Foo(flag : Bool) : Int { + use q = Qubit(); + if flag { + return 1; + } + Reset(q); + 0 + } + + @EntryPoint() + operation Main() : Int { + Foo(true) + } + } + "#}; + + let result = compile_no_hoist_return_unified(source); + assert_path_local_releases_without_unconditional_suffix(&result, "Foo"); + check_no_hoist_semantic_equivalence(source); +} + +#[test] +fn no_release_hoist_path_local_release_nested_qubit_scopes_stay_path_local() { + let source = indoc! 
{r#" + namespace Test { + operation Foo(flag : Bool) : Int { + use outer = Qubit(); + if flag { + use inner = Qubit(); + Reset(inner); + Reset(outer); + return 1; + } + Reset(outer); + 0 + } + + @EntryPoint() + operation Main() : Int { + Foo(true) + } + } + "#}; + + let result = compile_no_hoist_return_unified(source); + assert_path_local_releases_without_unconditional_suffix(&result, "Foo"); + check_no_hoist_semantic_equivalence(source); +} + +#[test] +fn no_release_hoist_path_local_release_qubit_arrays_stay_path_local() { + let source = indoc! {r#" + namespace Test { + operation Foo(flag : Bool) : Int { + use qs = Qubit[2]; + if flag { + return 1; + } + 0 + } + + @EntryPoint() + operation Main() : Int { + Foo(true) + } + } + "#}; + + let result = compile_no_hoist_return_unified(source); + assert_path_local_releases_without_unconditional_suffix(&result, "Foo"); + check_no_hoist_semantic_equivalence(source); +} + +#[test] +fn no_release_hoist_flag_strategy_guards_loop_scope_release_continuation() { + let source = indoc! {r#" + namespace Test { + @EntryPoint() + operation Main() : Int { + mutable i = 0; + while i < 5 { + use q = Qubit(); + if i == 3 { + return i; + } + i += 1; + } + -1 + } + } + "#}; + + let result = compile_no_hoist_return_unified(source); + assert_guarded_release_continuation(&result, "Main"); + check_no_hoist_semantic_equivalence(source); +} + +#[test] +fn no_release_hoist_flag_strategy_guards_body_scope_release_continuation() { + let source = indoc! 
{r#" + namespace Test { + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + mutable i = 0; + while i < 10 { + if i == 3 { + Reset(q); + return i; + } + i += 1; + } + Reset(q); + 0 + } + } + "#}; + + let result = compile_no_hoist_return_unified(source); + assert_guarded_release_continuation(&result, "Main"); + check_no_hoist_semantic_equivalence(source); +} + +/// Return-statement classification: `classify_return_stmt` maps +/// `StmtKind::Expr(Return(inner))` and `StmtKind::Semi(Return(inner))` +/// to the same `BareReturn(inner)` by design. Two callables differing +/// only in the trailing `;` must produce structurally identical +/// post-`return_unify` bodies. + +#[test] +fn qubit_release_guarded_in_for_loop_with_early_return() { + check_no_returns_q( + indoc! {r#" + namespace Test { + operation Main() : Int { + mutable result = 0; + for i in 0..4 { + use q = Qubit(); + if i == 3 { + result = i; + return result; + } + } + result + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + mutable result : Int = 0; + { + let + @range_id_41 : Range = 0..4; + mutable + @index_id_44 : Int = + @range_id_41::Start; + let + @step_id_49 : Int = + @range_id_41::Step; + let + @end_id_54 : Int = + @range_id_41::End; + while not __has_returned and + @step_id_49 > 0 and + @index_id_44 <= + @end_id_54 or + @step_id_49 < 0 and + @index_id_44 >= + @end_id_54 { + let i : Int = + @index_id_44; + let q : Qubit = __quantum__rt__qubit_allocate(); + if i == 3 { + result = i; + { + let + @generated_ident_89 : Int = result; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_89; + __has_returned = true; + }; + }; + } + + if not __has_returned { + @index_id_44 += + @step_id_49; + }; + if not __has_returned { + __quantum__rt__qubit_release(q); + }; + } + + } + + let __trailing_result : Int = result; + if __has_returned __ret_val else __trailing_result + } + } 
+ // entry + Main() + "#]], + ); +} + +#[test] +fn body_level_qubit_release_guarded_with_while_return() { + check_no_returns_q( + indoc! {r#" + namespace Test { + operation Main() : Int { + use q = Qubit(); + mutable i = 0; + while i < 10 { + if i == 3 { + Reset(q); + return i; + } + i += 1; + } + Reset(q); + 0 + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + let q : Qubit = __quantum__rt__qubit_allocate(); + mutable i : Int = 0; + while not __has_returned and i < 10 { + if i == 3 { + Reset(q); + { + let + @generated_ident_52 : Int = i; + __quantum__rt__qubit_release(q); + { + __ret_val = + @generated_ident_52; + __has_returned = true; + }; + }; + } + + if not __has_returned { + i += 1; + }; + } + + if not __has_returned { + Reset(q); + }; + let + @generated_ident_64 : Int = { + 0 + }; + if not __has_returned { + __quantum__rt__qubit_release(q); + }; + let __trailing_result : Int = + @generated_ident_64; + if __has_returned __ret_val else __trailing_result + } + } + // entry + Main() + "#]], + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/return_unify/tests/regressions.rs b/source/compiler/qsc_fir_transforms/src/return_unify/tests/regressions.rs new file mode 100644 index 0000000000..fda9f3279c --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/return_unify/tests/regressions.rs @@ -0,0 +1,257 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use super::*; + +#[test] +fn differential_triple_nested_if_return_known_bug() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Main() : Int { + if 0 > 0 { + if 0 > 0 { + if 0 > 0 { return 1; } + return 0; + } + 0 + } else { + return 2; + } + } + } + "#}); +} + +/// Simpler variant: return only in else branch with false condition. +/// Checks whether the bug requires deep nesting or just else-return under +/// a false condition. 
+ +#[test] +fn differential_else_return_false_condition() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Main() : Int { + if 0 > 0 { 42 } else { return 0; } + } + } + "#}); +} + +/// Structural snapshot: verifies the bind-then-check pattern in the FIR +/// output for the triple-nested if-return case. The trailing +/// expression is bound to `__trailing_result` before the `__has_returned` +/// flag is checked. + +#[test] +fn triple_nested_if_return_with_else_return() { + check_no_returns_q( + indoc! {r#" + namespace Test { + function Main() : Int { + if 0 > 0 { + if 0 > 0 { + if 0 > 0 { return 1; } + return 0; + } + 0 + } else { + return 2; + } + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + let __trailing_result : Int = if 0 > 0 { + if 0 > 0 { + if 0 > 0 { + { + __ret_val = 1; + __has_returned = true; + }; + } + + if not __has_returned { + { + __ret_val = 0; + __has_returned = true; + }; + }; + } + + 0 + } else { + { + __ret_val = 2; + __has_returned = true; + }; + }; + if __has_returned __ret_val else __trailing_result + } + } + // entry + Main() + "#]], + ); +} + +/// Semantic companion for `triple_nested_if_return_with_else_return`. + +#[test] +fn structured_strategy_preserves_releases_on_all_paths() { + let source = indoc! 
{r#" + namespace Test { + operation Foo(flag : Bool) : Int { + use q = Qubit(); + if flag { + return 1; + } + 0 + } + + @EntryPoint() + operation Main() : Int { + Foo(true) + } + } + "#}; + + let (store, pkg_id) = compile_return_unified(source); + let package = store.get(pkg_id); + + let body_block_id = find_body_block_id(package, "Foo"); + let body_block = package.get_block(body_block_id); + + let release_callables = collect_release_callables(&store); + let release_indices = body_block + .stmts + .iter() + .enumerate() + .filter_map(|(index, &stmt_id)| { + is_release_call_test(package, stmt_id, &release_callables).then_some(index) + }) + .collect::>(); + assert!( + release_indices.is_empty(), + "structured strategy should not keep a top-level release suffix after path-local releases" + ); + + let has_path_local_release = body_block.stmts.iter().any(|&stmt_id| { + stmt_contains_path_local_release_value(package, stmt_id, &release_callables) + }); + assert!( + has_path_local_release, + "structured strategy must preserve release calls inside value-producing paths" + ); + + let trailing_stmt_id = *body_block + .stmts + .last() + .expect("Foo body should not be empty"); + let StmtKind::Expr(trailing_expr_id) = package.get_stmt(trailing_stmt_id).kind else { + panic!("Foo body should end with a trailing expression"); + }; + assert_eq!( + package.get_expr(trailing_expr_id).ty, + Ty::Prim(Prim::Int), + "Foo body should keep an Int-producing trailing expression" + ); + + check_semantic_equivalence(source); +} + +#[test] +fn if_both_return_release_suffix_before_after_qsharp() { + check_pre_fir_transforms_to_return_unify_q( + indoc! 
{r#" + namespace Test { + operation Foo(flag : Bool) : Int { + use q = Qubit(); + if flag { + return 1; + } else { + return 0; + } + } + + @EntryPoint() + operation Main() : Int { + Foo(true) + } + } + "#}, + &expect![[r#" + // before fir transforms + // namespace Test + operation Foo(flag : Bool) : Int { + body { + let q : Qubit = __quantum__rt__qubit_allocate(); + let + @generated_ident_65 : Unit = if flag { + { + let + @generated_ident_41 : Int = 1; + __quantum__rt__qubit_release(q); + return + @generated_ident_41; + }; + } else { + { + let + @generated_ident_53 : Int = 0; + __quantum__rt__qubit_release(q); + return + @generated_ident_53; + }; + }; + __quantum__rt__qubit_release(q); + @generated_ident_65 + } + } + operation Main() : Int { + body { + Foo(true) + } + } + // entry + Main() + + // post return_unify + // namespace Test + operation Foo(flag : Bool) : Int { + body { + let q : Qubit = __quantum__rt__qubit_allocate(); + if flag { + { + let + @generated_ident_41 : Int = 1; + __quantum__rt__qubit_release(q); + @generated_ident_41 + } + + } else { + { + let + @generated_ident_53 : Int = 0; + __quantum__rt__qubit_release(q); + @generated_ident_53 + } + + } + + } + } + operation Main() : Int { + body { + Foo(true) + } + } + // entry + Main() + "#]], + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/return_unify/tests/semantic.rs b/source/compiler/qsc_fir_transforms/src/return_unify/tests/semantic.rs new file mode 100644 index 0000000000..c26c205e15 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/return_unify/tests/semantic.rs @@ -0,0 +1,869 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use super::*; + +#[test] +fn outer_return_wrapping_if_with_stmt_return_in_else_does_not_loop_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + import Std.Measurement.*; + + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + return if M(q) == One { + 1 + } else { + return M(q) == One ? 
0 | 1; + }; + } + } + "#}); +} + +/// Evaluates the entry exec graph of the given FIR store with a fixed +/// simulator seed for determinism. Returns `Ok(value)` on success, or +/// `Err(error_string)` on evaluation failure. + +#[test] +fn while_divzero_condition_short_circuits_after_return_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Main() : Int { + mutable i = 0; + while 10 / (3 - i) > 0 { + i += 1; + if i == 3 { + return i; + } + } + -1 + } + } + "#}); +} + +#[test] +fn while_mixed_condition_and_body_returns_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Main() : Int { + mutable i = 0; + while (if i > 5 { return 99; } else { true }) { + i += 1; + if i == 3 { + return i; + } + } + -1 + } + } + "#}); +} + +#[test] +fn bare_return_with_dead_code_semantic() { + // Classical version: exercises the same bare-return + dead-code + // truncation path without qubit scope asymmetry. + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Main() : Int { + let x = 1; + return 42; + let y = x + 1; + y + 2 + } + } + "#}); +} + +#[test] +fn return_after_dynamic_branch_with_dead_code_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + operation Main() : Unit { + use q = Qubit(); + if M(q) == One { + X(q); + } else { + H(q); + } + H(q); + return (); + Y(q); + } + } + "#}); +} + +#[test] +fn nested_if_with_returns_at_different_levels_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Main() : Int { + if true { + if false { + return 1; + } + return 2; + } + 3 + } + } + "#}); +} + +#[test] +fn nested_block_middle_of_block_fix_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Main() : Int { + let c = true; + let _unused = { + if c { return 1; } + 2 + }; + let y = 3; + y + } + } + "#}); +} + +#[test] +fn hoist_return_in_range_endpoint_semantic() { + check_semantic_equivalence(indoc! 
{r#" + namespace Test { + function Main() : Int { + mutable sum = 0; + for i in 0..(return 5) { + sum += i; + } + sum + } + } + "#}); +} + +#[test] +fn return_bool_in_dynamic_branch_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + operation Main() : Bool { + use q = Qubit(); + if M(q) == One { + return true; + } + false + } + } + "#}); +} + +#[test] +fn return_unit_after_side_effects_semantic() { + // Classical version: exercises the same early-return-unit + remaining + // side-effects path without qubit scope asymmetry. + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Main() : Unit { + mutable x = 0; + if x == 0 { + x = 1; + return (); + } + x = 2; + } + } + "#}); +} + +#[test] +fn both_branches_return_with_qubit_scope_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + operation Main() : Bool { + use q = Qubit(); + let r = M(q); + Reset(q); + if r == One { + return true; + } else { + return false; + } + } + } + "#}); +} + +#[test] +fn for_loop_with_early_return_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Main() : Int { + for i in 0..10 { + if i == 5 { + return i; + } + } + -1 + } + } + "#}); +} + +#[test] +fn deeply_nested_block_with_return_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Main() : Int { + let x = { + if true { + return 10; + } + 5 + }; + x + } + } + "#}); +} + +#[test] +fn multiple_returns_in_helper_function_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Classify(x : Int) : Int { + if x > 0 { + return 1; + } + if x < 0 { + return -1; + } + 0 + } + function Main() : Int { + Classify(5) + } + } + "#}); +} + +#[test] +fn guard_clause_with_existing_else_and_remaining_semantic() { + check_semantic_equivalence(indoc! 
{r#" + namespace Test { + function Main() : Int { + if true { + return 1; + } else { + let _ = 0; + } + 2 + } + } + "#}); +} + +#[test] +fn return_tuple_value_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Main() : (Int, Bool) { + if true { + return (1, true); + } + (0, false) + } + } + "#}); +} + +// Recursive function with early return + +#[test] +fn recursive_function_with_return_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Factorial(n : Int) : Int { + if n <= 1 { + return 1; + } + n * Factorial(n - 1) + } + function Main() : Int { + Factorial(5) + } + } + "#}); +} + +// Tuple return + while + nested if + +#[test] +fn tuple_return_in_while_with_nested_if_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Main() : (Int, Bool) { + mutable i = 0; + while i < 10 { + if i > 5 { + if i == 7 { + return (i, true); + } + } + i += 1; + } + (-1, false) + } + } + "#}); +} + +// All 4 specializations with flag strategy (for-loop desugar) + +#[test] +fn qubit_alloc_scope_with_flag_strategy_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + operation Main() : Int { + mutable i = 0; + while i < 5 { + use q = Qubit(); + if i == 3 { + return i; + } + i += 1; + } + -1 + } + } + "#}); +} + +// repeat-until + return (desugared to while at HIR) + +#[test] +fn repeat_until_with_return_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + operation Main() : Int { + mutable result = 0; + mutable attempt = 0; + repeat { + if attempt > 3 { + return -1; + } + attempt += 1; + result = attempt * 2; + } until result > 5; + result + } + } + "#}); +} + +// fail + return in same control flow + +#[test] +fn while_body_side_effect_guarded_after_return_semantic() { + check_semantic_equivalence(indoc! 
{r#" + namespace Test { + operation Main() : Int { + mutable sum = 0; + mutable i = 0; + while i < 10 { + if i == 3 { + return sum; + } + sum += i; + i += 1; + } + sum + } + } + "#}); +} + +// Qubit alloc scope + flag strategy — release continuations are guarded + +#[test] +fn qubit_release_guarded_in_for_loop_with_early_return_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + operation Main() : Int { + mutable result = 0; + for i in 0..4 { + use q = Qubit(); + if i == 3 { + result = i; + return result; + } + } + result + } + } + "#}); +} + +#[test] +fn body_level_qubit_release_guarded_with_while_return_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + operation Main() : Int { + use q = Qubit(); + mutable i = 0; + while i < 10 { + if i == 3 { + Reset(q); + return i; + } + i += 1; + } + Reset(q); + 0 + } + } + "#}); +} + +#[test] +fn if_expr_init_with_while_return_uses_flag_strategy_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Main() : Int { + let x = if true { + mutable i = 0; + while i < 5 { + if i == 3 { + return 42; + } + i += 1; + } + 0 + } else { + 1 + }; + x + } + } + "#}); +} + +#[test] +fn simple_if_expr_init_with_return_stays_structured_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Main() : Int { + if true { + return 10; + } + let x = if false { return 20; } else { 30 }; + x + } + } + "#}); +} + +#[test] +fn flag_strategy_guards_local_after_return_semantic() { + check_semantic_equivalence(indoc! 
{r#" + namespace Test { + function Main() : Int { + mutable i = 0; + while i < 5 { + if i == 3 { + return i; + } + let y = i * 2; + i += 1; + } + -1 + } + } + "#}); +} + +// Tests excluded from semantic comparison (no `_semantic` companion): +// +// Error-contract tests (test panics/errors, not values): +// - guard_stmt_with_flag_rejects_non_unit_expr_stmt (#[should_panic]) +// - flag_trailing_without_trailing_expr_rejects_non_unit_contract (#[should_panic]) +// - unsupported_return_slot_default_in_flag_strategy_produces_error (expects error list) +// - unsupported_guarded_local_default_in_flag_strategy_is_explicit_contract (#[should_panic]) +// - qubit_return_in_while_produces_error (expects error list) +// +// Specialization tests (Adj/Ctl, no single entry point output): +// - explicit_specialization_bodies_are_return_unified +// - simulatable_intrinsic_body_is_return_unified +// - all_four_specializations_with_return_in_loop +// +// Arrow-typed return tests (blocked by defunctionalization limitation): +// - arrow_typed_return_in_structured_path +// +// No-return or identity tests (no transform to validate): +// - no_op_function_without_returns +// - already_normalized_idempotency +// - lowered_reachable_callables_do_not_emit_while_local_initializers (no returns in source) +// +// Non-standard compilation flow (synthetic FIR or direct transform call): +// - synthetic_while_local_initializer_shape_still_eliminates_returns +// - nested_block_with_while_return_not_transformable_by_if_else +// +// Structural comparison only (compares two sources, not runtime values): +// - classify_semi_return_and_expr_return_produce_same_shape + +#[test] +fn single_trailing_return_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Main() : Int { + return 42; + } + } + "#}); +} + +#[test] +fn guard_clause_pattern_semantic() { + check_semantic_equivalence(indoc! 
{r#" + namespace Test { + function Main() : Int { + if true { + return 1; + } + 0 + } + } + "#}); +} + +#[test] +fn multiple_guard_clauses_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Main() : Int { + if true { + return 1; + } + if false { + return 2; + } + if true { + return 3; + } + 0 + } + } + "#}); +} + +#[test] +fn both_branches_return_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Main() : Int { + if true { + return 1; + } else { + return 2; + } + } + } + "#}); +} + +#[test] +fn return_in_nested_block_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Main() : Int { + { + { + return 10; + } + }; + 5 + } + } + "#}); +} + +#[test] +fn return_inside_while_loop_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Main() : Int { + mutable i = 0; + while i < 10 { + if i == 5 { + return i; + } + i += 1; + } + -1 + } + } + "#}); +} + +#[test] +fn while_return_tuple_value_uses_flag_fallback_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Main() : (Int, Bool) { + mutable i = 0; + while i < 3 { + if i == 1 { + return (i, true); + } + i += 1; + } + (-1, false) + } + } + "#}); +} + +#[test] +fn while_return_array_value_uses_flag_fallback_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Main() : Int[] { + mutable i = 0; + while i < 3 { + if i == 1 { + return [i, i + 1]; + } + i += 1; + } + [] + } + } + "#}); +} + +#[test] +fn while_local_initializer_if_return_is_rewritten_by_flag_strategy_semantic() { + check_semantic_equivalence(indoc! 
{r#" + namespace Test { + function Add(a : Int, b : Int) : Int { a + b } + function Main() : Int { + mutable i = 0; + while i < 3 { + let _ = if i == 1 { + Add((return 42), i) + }; + i += 1; + } + i + 5 + } + } + "#}); +} + +#[test] +fn while_local_initializer_if_else_return_preserves_fallthrough_tail_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Add(a : Int, b : Int) : Int { a + b } + function Main() : Int { + mutable i = 0; + while i < 3 { + let x = if i == 1 { + Add((return 7), i) + } else { + i + 10 + }; + i += x; + } + let tail = i + 5; + tail + } + } + "#}); +} + +#[test] +fn nested_loop_exit_convergence_is_guarded_by_flag_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Main() : Int { + mutable outer = 0; + mutable inner = 0; + while outer < 2 { + while inner < 2 { + if inner == 1 { + return outer + inner; + } + inner += 1; + } + outer += 1; + inner = 0; + } + -1 + } + } + "#}); +} + +#[test] +fn unit_returning_with_return_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Main() : Unit { + if true { + return (); + } + } + } + "#}); +} + +#[test] +fn while_body_call_arg_return_keeps_loop_before_trailing_merge_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Add(a : Int, b : Int) : Int { a + b } + + function Main() : Int { + mutable i = 0; + while i < 3 { + let _ = Add((return 42), 2); + i += 1; + } + -1 + } + } + "#}); +} + +#[test] +fn return_value_is_complex_expression_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Add(a : Int, b : Int) : Int { a + b } + function Main() : Int { + if true { + return Add(1, 2) + Add(3, 4); + } + 0 + } + } + "#}); +} + +#[test] +fn return_in_else_branch_only_semantic() { + check_semantic_equivalence(indoc! 
{r#" + namespace Test { + function Main() : Int { + if true { + 1 + } else { + return 2; + } + } + } + "#}); +} + +#[test] +fn range_return_default_in_flag_strategy_is_supported_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Main() : Range { + mutable i = 0; + while i < 1 { + return 0..1; + } + 2..3 + } + } + "#}); +} + +#[test] +fn fail_and_return_in_same_control_flow_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + function Main() : Int { + let c = true; + if c { + return 42; + } else { + fail "unreachable"; + } + } + } + "#}); +} + +// Quantum semantic companions (added after qubit-scope semantic fix) + +#[test] +fn nested_qubit_scope_return_updates_outer_block_type_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + import Std.Measurement.*; + + operation Main() : Result { + use outer = Qubit() { + use qubit = Qubit() { + let result = MResetZ(qubit); + Reset(outer); + return result; + } + } + } + } + "#}); +} + +#[test] +fn early_return_in_qubit_array_scope_preserves_release_order_semantic() { + check_semantic_equivalence(indoc! {r#" + namespace Test { + operation Foo(flag : Bool) : Int { + use qs = Qubit[2]; + if flag { + return 1; + } + 0 + } + + operation Main() : Int { + Foo(true) + } + } + "#}); +} + +// Idempotency tests +// +// Verify that running `unify_returns` a second time on already-transformed +// FIR is a no-op: no new arena entries (blocks, stmts, exprs, pats) are +// allocated, and no errors are produced. + +/// Helper: compile through `return_unify`, then run `unify_returns` again and +/// assert that the package arenas are unchanged (no new IDs allocated). + +#[test] +fn triple_nested_if_return_with_else_return_semantic() { + check_semantic_equivalence(indoc! 
{r#" + namespace Test { + function Main() : Int { + if 0 > 0 { + if 0 > 0 { + if 0 > 0 { return 0; } + return 0; + } + 0 + } else { + return 0; + } + } + } + "#}); +} diff --git a/source/compiler/qsc_fir_transforms/src/return_unify/tests/structured_strategy.rs b/source/compiler/qsc_fir_transforms/src/return_unify/tests/structured_strategy.rs new file mode 100644 index 0000000000..85ce2684ae --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/return_unify/tests/structured_strategy.rs @@ -0,0 +1,1294 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use super::*; + +#[test] +fn no_op_function_without_returns() { + // A function with no return statements should pass through unchanged. + check_no_returns_q( + indoc! {r#" + namespace Test { + function Main() : Int { + let x = 1; + x + 2 + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + let x : Int = 1; + x + 2 + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn single_trailing_return() { + // `return x;` as the last statement should be simplified to just `x`. + check_no_returns_q( + indoc! {r#" + namespace Test { + function Main() : Int { + return 42; + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + 42 + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn guard_clause_pattern() { + // `if cond { return a; } b` → `if cond { a } else { b }` + check_no_returns_q( + indoc! {r#" + namespace Test { + function Main() : Int { + if true { + return 1; + } + 0 + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + if true { + 1 + } else { + 0 + } + + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn multiple_guard_clauses() { + // Three sequential if-return → nested if-else chain. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + function Main() : Int { + if true { + return 1; + } + if false { + return 2; + } + if true { + return 3; + } + 0 + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + if true { + 1 + } else { + if false { + 2 + } else { + if true { + 3 + } else { + 0 + } + + } + + } + + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn both_branches_return() { + // `if cond { return a; } else { return b; }` → `if cond { a } else { b }` + check_no_returns_q( + indoc! {r#" + namespace Test { + function Main() : Int { + if true { + return 1; + } else { + return 2; + } + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + if true { + 1 + } else { + 2 + } + + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn both_branches_return_with_qubit_scope() { + // Both branches return inside a qubit scope — tests interaction with + // `replace_qubit_allocation` which inserts release calls. + check_no_returns_q( + indoc! {r#" + namespace Test { + operation Main() : Bool { + use q = Qubit(); + let r = M(q); + Reset(q); + if r == One { + return true; + } else { + return false; + } + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Bool { + body { + let q : Qubit = __quantum__rt__qubit_allocate(); + let r : Result = M(q); + Reset(q); + if r == One { + { + let + @generated_ident_43 : Bool = true; + __quantum__rt__qubit_release(q); + @generated_ident_43 + } + + } else { + { + let + @generated_ident_55 : Bool = false; + __quantum__rt__qubit_release(q); + @generated_ident_55 + } + + } + + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn return_in_nested_block() { + // `{ { return x; } }` → `{ { x } }` + check_no_returns_q( + indoc! 
{r#" + namespace Test { + function Main() : Int { + { + { + return 10; + } + }; + 5 + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + { + { + 10 + } + + } + + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn unit_returning_with_return() { + // `return ();` patterns in Unit-returning operations. + check_no_returns_q( + indoc! {r#" + namespace Test { + function Main() : Unit { + if true { + return (); + } + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Unit { + body { + if true { + () + } else {} + + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn explicit_specialization_bodies_are_return_unified() { + check_structure( + indoc! {r#" + namespace Test { + operation Foo(n : Int, q : Qubit) : Unit is Adj + Ctl { + body ... { + if n == 0 { + return (); + } + H(q); + } + adjoint ... { + if n == 1 { + return (); + } + X(q); + } + controlled (ctls, ...) { + if Length(ctls) == 0 { + return (); + } + Controlled H(ctls, q); + } + controlled adjoint (ctls, ...) { + if Length(ctls) == 1 { + return (); + } + Controlled X(ctls, q); + } + } + + @EntryPoint() + operation Main() : Unit { + use q = Qubit(); + Foo(1, q); + } + } + "#}, + &["Foo", "Main"], + &expect![[r#" +callable Foo: input_ty=(Int, Qubit), output_ty=Unit + body: block_ty=Unit + [0] Expr If(cond=BinOp(Eq)[ty=Bool], then=Block[ty=Unit], else=Block[ty=Unit]) + adj: block_ty=Unit + [0] Expr If(cond=BinOp(Eq)[ty=Bool], then=Block[ty=Unit], else=Block[ty=Unit]) + ctl: block_ty=Unit + [0] Expr If(cond=BinOp(Eq)[ty=Bool], then=Block[ty=Unit], else=Block[ty=Unit]) + ctl_adj: block_ty=Unit + [0] Expr If(cond=BinOp(Eq)[ty=Bool], then=Block[ty=Unit], else=Block[ty=Unit]) +callable Main: input_ty=Unit, output_ty=Unit + body: block_ty=Unit + [0] Local(Immutable, q: Qubit): Call[ty=Qubit] + [1] Semi Call[ty=Unit] + [2] Semi Call[ty=Unit]"#]], + ); +} + +#[test] +fn simulatable_intrinsic_body_is_return_unified() { + check_structure( + indoc! 
{r#" + namespace Test { + @SimulatableIntrinsic() + operation Foo() : Int { + mutable i = 0; + while i < 3 { + if i == 1 { + return i; + } + i += 1; + } + -1 + } + + @EntryPoint() + operation Main() : Int { + Foo() + } + } + "#}, + &["Foo", "Main"], + &expect![[r#" + callable Foo: input_ty=Unit, output_ty=Int + simulatable: block_ty=Int + [0] Local(Mutable, __has_returned: Bool): Lit(Bool(false)) + [1] Local(Mutable, __ret_val: Int): Lit(Int(0)) + [2] Local(Mutable, i: Int): Lit(Int(0)) + [3] Expr While[ty=Unit] + [4] Local(Immutable, __trailing_result: Int): UnOp(Neg)[ty=Int] + [5] Expr If(cond=Var[ty=Bool], then=Var[ty=Int], else=Var[ty=Int]) + callable Main: input_ty=Unit, output_ty=Int + body: block_ty=Int + [0] Expr Call[ty=Int]"#]], + ); +} + +#[test] +fn already_normalized_idempotency() { + // Running on already-normalized code (no returns) produces no changes. + let source = indoc! {r#" + namespace Test { + function Main() : Int { + if true { + 1 + } else { + 2 + } + } + } + "#}; + // Snapshot pins the stable output; any divergence fails the check. + check_no_returns_q( + source, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + if true { + 1 + } else { + 2 + } + + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn return_value_is_complex_expression() { + // `return f(x) + g(y);` style complex expression. + check_no_returns_q( + indoc! {r#" + namespace Test { + function Add(a : Int, b : Int) : Int { a + b } + function Main() : Int { + if true { + return Add(1, 2) + Add(3, 4); + } + 0 + } + } + "#}, + &expect![[r#" + // namespace Test + function Add(a : Int, b : Int) : Int { + body { + a + b + } + } + function Main() : Int { + body { + if true { + Add(1, 2) + Add(3, 4) + } else { + 0 + } + + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn return_in_else_branch_only() { + check_no_returns_q( + indoc! 
{r#" + namespace Test { + function Main() : Int { + if true { + 1 + } else { + return 2; + } + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + if not true { + 2 + } else { + 1 + } + + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn return_bool_in_dynamic_branch() { + // Quantum operation with dynamic branch using measurement. + check_no_returns_q( + indoc! {r#" + namespace Test { + operation Main() : Bool { + use q = Qubit(); + if M(q) == One { + return true; + } + false + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Bool { + body { + let q : Qubit = __quantum__rt__qubit_allocate(); + if M(q) == One { + { + let + @generated_ident_32 : Bool = true; + __quantum__rt__qubit_release(q); + @generated_ident_32 + } + + } else { + let + @generated_ident_44 : Bool = false; + __quantum__rt__qubit_release(q); + @generated_ident_44 + } + + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn multiple_returns_in_helper_function() { + // Helper function called from entry point with multiple returns. + check_no_returns_q( + indoc! {r#" + namespace Test { + function Classify(x : Int) : Int { + if x > 0 { + return 1; + } + if x < 0 { + return -1; + } + 0 + } + function Main() : Int { + Classify(5) + } + } + "#}, + &expect![[r#" + // namespace Test + function Classify(x : Int) : Int { + body { + if x > 0 { + 1 + } else { + if x < 0 { + -1 + } else { + 0 + } + + } + + } + } + function Main() : Int { + body { + Classify(5) + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn return_unit_after_side_effects() { + check_no_returns_q( + indoc! 
{r#" + namespace Test { + operation Main() : Unit { + use q = Qubit(); + H(q); + if M(q) == One { + X(q); + return (); + } + Y(q); + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Unit { + body { + let q : Qubit = __quantum__rt__qubit_allocate(); + H(q); + if M(q) == One { + X(q); + { + let + @generated_ident_42 : Unit = (); + __quantum__rt__qubit_release(q); + @generated_ident_42 + } + + } else { + Y(q); + __quantum__rt__qubit_release(q); + } + + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn bare_return_with_dead_code() { + // `return x; dead_code;` — apply_bare_return must truncate statements + // after the return. + check_no_returns_q( + indoc! {r#" + namespace Test { + operation Main() : Int { + use q = Qubit(); + H(q); + return 42; + let x = 1; + x + 2 + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Int { + body { + let q : Qubit = __quantum__rt__qubit_allocate(); + H(q); + { + let + @generated_ident_33 : Int = 42; + __quantum__rt__qubit_release(q); + @generated_ident_33 + } + + } + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn nested_if_with_returns_at_different_levels() { + // Returns at two levels of if nesting: the innermost if-return is lifted + // first, then the outer if-return is lifted. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + function Main() : Int { + if true { + if false { + return 1; + } + return 2; + } + 3 + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + if true { + if false { + { + __ret_val = 1; + __has_returned = true; + }; + } + + if not __has_returned { + { + __ret_val = 2; + __has_returned = true; + }; + }; + } + + let __trailing_result : Int = 3; + if __has_returned __ret_val else __trailing_result + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn return_tuple_value() { + // Return of a compound (tuple) type exercises type propagation + // through strip_returns_from_expr. + check_no_returns_q( + indoc! {r#" + namespace Test { + function Main() : (Int, Bool) { + if true { + return (1, true); + } + (0, false) + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : (Int, Bool) { + body { + if true { + (1, true) + } else { + (0, false) + } + + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn guard_clause_with_existing_else_and_remaining() { + // if-return with an existing else body AND remaining statements after + // the if — exercises apply_if_then_return's else prepend path. + check_no_returns_q( + indoc! {r#" + namespace Test { + function Main() : Int { + if true { + return 1; + } else { + let _ = 0; + } + 2 + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + if true { + 1 + } else { + { + let _ : Int = 0; + }; + 2 + } + + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn deeply_nested_block_with_return() { + // Return inside multiple levels of nested blocks exercises + // NestedBlock recursion in classify_return_stmt. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + function Main() : Int { + let x = { + if true { + return 10; + } + 5 + }; + x + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + let x : Int = { + if true { + { + __ret_val = 10; + __has_returned = true; + }; + } + + 5 + }; + let __trailing_result : Int = x; + if __has_returned __ret_val else __trailing_result + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn return_after_dynamic_branch_with_dead_code() { + // Dynamic branch followed by early return followed by dead code. + // Exercises BareReturn truncation after a non-classical if-else. + check_no_returns_q( + indoc! {r#" + namespace Test { + operation Main() : Unit { + use q = Qubit(); + if M(q) == One { + X(q); + } else { + H(q); + } + H(q); + return (); + Y(q); + } + } + "#}, + &expect![[r#" + // namespace Test + operation Main() : Unit { + body { + let q : Qubit = __quantum__rt__qubit_allocate(); + if M(q) == One { + X(q); + } else { + H(q); + } + + H(q); + { + let + @generated_ident_48 : Unit = (); + __quantum__rt__qubit_release(q); + @generated_ident_48 + } + + } + } + function Length(a : Pauli[]) : Int { + body intrinsic; + } + function Length(a : Qubit[]) : Int { + body intrinsic; + } + // entry + Main() + "#]], + ); +} + +#[test] +fn for_loop_with_early_return() { + // For loops desugar to a block wrapping locals + while in FIR. + // The While is nested inside a Block expression, so transform_while_stmt + // must descend through Block wrappers to find and transform it. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + function Main() : Int { + for i in 0..10 { + if i == 5 { + return i; + } + } + -1 + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + mutable __has_returned : Bool = false; + mutable __ret_val : Int = 0; + { + let + @range_id_30 : Range = 0..10; + mutable + @index_id_33 : Int = + @range_id_30::Start; + let + @step_id_38 : Int = + @range_id_30::Step; + let + @end_id_43 : Int = + @range_id_30::End; + while not __has_returned and + @step_id_38 > 0 and + @index_id_33 <= + @end_id_43 or + @step_id_38 < 0 and + @index_id_33 >= + @end_id_43 { + let i : Int = + @index_id_33; + if i == 5 { + { + __ret_val = i; + __has_returned = true; + }; + } + + if not __has_returned { + @index_id_33 += + @step_id_38; + }; + } + + } + + let __trailing_result : Int = -1; + if __has_returned __ret_val else __trailing_result + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn nested_qubit_scope_return_updates_outer_block_type() { + check_structure( + indoc! {r#" + namespace Test { + import Std.Measurement.*; + + @EntryPoint() + operation Main() : Result { + use outer = Qubit() { + use qubit = Qubit() { + let result = MResetZ(qubit); + Reset(outer); + return result; + } + } + } + } + "#}, + &["Main"], + &expect![[r#" + callable Main: input_ty=Unit, output_ty=Result + body: block_ty=Result + [0] Expr Block[ty=Result]"#]], + ); +} + +#[test] +fn early_return_in_qubit_array_scope_preserves_release_order() { + let source = indoc! 
{r#" + namespace Test { + operation Foo(flag : Bool) : Int { + use qs = Qubit[2]; + if flag { + return 1; + } + 0 + } + + @EntryPoint() + operation Main() : Int { + Foo(true) + } + } + "#}; + + let (store, pkg_id) = compile_return_unified(source); + let package = store.get(pkg_id); + let body_block_id = find_body_block_id(package, "Foo"); + let body_block = package.get_block(body_block_id); + let has_path_local_array_release = body_block.stmts.iter().any(|&stmt_id| { + stmt_tree_calls_named_callable(&store, package, stmt_id, "ReleaseQubitArray") + }); + assert!( + has_path_local_array_release, + "Foo body should preserve ReleaseQubitArray on value-producing paths" + ); + + let has_unconditional_array_release_suffix = body_block + .stmts + .iter() + .any(|&stmt_id| stmt_calls_named_callable(&store, package, stmt_id, "ReleaseQubitArray")); + assert!( + !has_unconditional_array_release_suffix, + "Foo body should not keep an unconditional ReleaseQubitArray suffix after path-local releases" + ); +} + +#[test] +fn classify_semi_return_and_expr_return_produce_same_shape() { + let semi_source = indoc! {r#" + namespace Test { + @EntryPoint() + operation Main() : Int { + return 1; + } + } + "#}; + let expr_source = indoc! {r#" + namespace Test { + @EntryPoint() + operation Main() : Int { + return 1 + } + } + "#}; + + let (semi_store, semi_pkg_id) = compile_return_unified(semi_source); + let (expr_store, expr_pkg_id) = compile_return_unified(expr_source); + + let semi_summary = summarize_callable(semi_store.get(semi_pkg_id), "Main"); + let expr_summary = summarize_callable(expr_store.get(expr_pkg_id), "Main"); + assert_eq!( + semi_summary, expr_summary, + "Semi-Return and Expr-Return callables must produce identical post-return_unify shapes", + ); +} + +/// Flag-guarded stmt type check: `guard_stmt_with_flag` requires a +/// Unit-typed inner stmt. Passing a non-Unit `StmtKind::Expr` must trip +/// the debug assertion. 
Gated on debug builds because `debug_assert!` is +/// elided in release. +#[cfg(debug_assertions)] +#[test] +fn outer_return_wrapping_if_with_stmt_return_in_else_does_not_loop() { + check_structure( + indoc! {r#" + namespace Test { + import Std.Measurement.*; + + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + return if M(q) == One { + 1 + } else { + return M(q) == One ? 0 | 1; + }; + } + } + "#}, + &["Main"], + &expect![[r#" + callable Main: input_ty=Unit, output_ty=Int + body: block_ty=Int + [0] Local(Immutable, q: Qubit): Call[ty=Qubit] + [1] Expr Block[ty=Int]"#]], + ); +} + +#[test] +fn outer_return_wrapping_if_with_stmt_return_in_else_full_pipeline() { + // Verify the full pipeline (including PostAll invariant checks) succeeds + // now that If expression types and Pat types are synchronized after + // return replacement. + let source = indoc! {r#" + namespace Test { + import Std.Measurement.*; + + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + return if M(q) == One { + 1 + } else { + return M(q) == One ? 0 | 1; + }; + } + } + "#}; + + let _ = compile_and_run_pipeline_to(source, PipelineStage::Full); +} + +#[test] +fn recursive_function_with_return() { + check_no_returns_q( + indoc! {r#" + namespace Test { + function Factorial(n : Int) : Int { + if n <= 1 { + return 1; + } + n * Factorial(n - 1) + } + function Main() : Int { + Factorial(5) + } + } + "#}, + &expect![[r#" + // namespace Test + function Factorial(n : Int) : Int { + body { + if n <= 1 { + 1 + } else { + n * Factorial(n - 1) + } + + } + } + function Main() : Int { + body { + Factorial(5) + } + } + // entry + Main() + "#]], + ); +} + +#[test] +fn fail_and_return_in_same_control_flow() { + check_no_returns_q( + indoc! 
{r#" + namespace Test { + function Main() : Int { + let c = true; + if c { + return 42; + } else { + fail "unreachable"; + } + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + let c : Bool = true; + if c { + 42 + } else { + fail $"unreachable"; + } + + } + } + // entry + Main() + "#]], + ); +} + +// Arrow-typed return in structured path + +#[test] +fn arrow_typed_return_in_structured_path() { + check_no_returns_q( + indoc! {r#" + namespace Test { + function Choose(flag : Bool) : (Int -> Int) { + if flag { + return x -> x + 1; + } + x -> x * 2 + } + function Main() : Int { + let f = Choose(true); + f(10) + } + } + "#}, + &expect![[r#" + // namespace Test + function Choose(flag : Bool) : (Int -> Int) { + body { + if flag { + / * closure item = 3 captures = [] * / < lambda > + } else { + / * closure item = 4 captures = [] * / < lambda > + } + + } + } + function Main() : Int { + body { + let f : (Int -> Int) = Choose(true); + f(10) + } + } + function < lambda > (x : Int, ) : Int { + body { + x + 1 + } + } + function < lambda > (x : Int, ) : Int { + body { + x * 2 + } + } + // entry + Main() + "#]], + ); +} + +// semantic test omitted: the program returns callable values which +// trigger a defunctionalization convergence failure in the full pipeline. +// The structural test above validates that return_unify handles this pattern. + +// Qubit return + while — triggers the error path + +#[test] +fn simple_if_expr_init_with_return_stays_structured() { + // Simple return directly in an if-branch initializer — the structured + // strategy handles this via strip_returns_from_expr without flags. + check_no_returns_q( + indoc! 
{r#" + namespace Test { + function Main() : Int { + if true { + return 10; + } + let x = if false { return 20; } else { 30 }; + x + } + } + "#}, + &expect![[r#" + // namespace Test + function Main() : Int { + body { + if true { + 10 + } else { + if false { + 20 + } else { + let x : Int = { + 30 + }; + x + } + + } + + } + } + // entry + Main() + "#]], + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/return_unify/tests/type_preservation.rs b/source/compiler/qsc_fir_transforms/src/return_unify/tests/type_preservation.rs new file mode 100644 index 0000000000..7b4e2e79d9 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/return_unify/tests/type_preservation.rs @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use super::*; + +#[test] +fn type_preservation_structured_strategy() { + // Structured strategy rewrites block tails — invariant checked in pipeline. + let (_store, _pkg_id) = compile_return_unified(indoc! {r#" + namespace Test { + function Main() : Int { + if true { + return 1; + } + 0 + } + } + "#}); +} + +#[test] +fn type_preservation_flag_strategy_int() { + // Flag strategy with Int return — invariant checked in pipeline. + let (_store, _pkg_id) = compile_return_unified(indoc! {r#" + namespace Test { + function Main() : Int { + mutable i = 0; + while i < 10 { + if i == 5 { + return i; + } + i += 1; + } + i + } + } + "#}); +} + +#[test] +fn type_preservation_tuple_return() { + // Tuple return type — invariant checked in pipeline. + let (_store, _pkg_id) = compile_return_unified(indoc! {r#" + namespace Test { + function Main() : (Int, Bool) { + if true { + return (1, true); + } + (0, false) + } + } + "#}); +} + +#[test] +fn type_preservation_nested_block_expr() { + // Nested block expression return — invariant checked in pipeline. + let (_store, _pkg_id) = compile_return_unified(indoc! 
{r#" + namespace Test { + function Main() : Int { + let x = { + if true { + return 1; + } + 2 + }; + x + } + } + "#}); +} + +#[test] +fn type_preservation_double_return() { + // Double return type — invariant checked in pipeline. + let (_store, _pkg_id) = compile_return_unified(indoc! {r#" + namespace Test { + function Main() : Double { + if true { + return 1.0; + } + 2.0 + } + } + "#}); +} diff --git a/source/compiler/qsc_fir_transforms/src/sroa.rs b/source/compiler/qsc_fir_transforms/src/sroa.rs new file mode 100644 index 0000000000..46ba57dee5 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/sroa.rs @@ -0,0 +1,826 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Scalar Replacement of Aggregates (SROA) pass. +//! +//! Replaces local variables of tuple type with individual scalar variables, +//! eliminating intermediate tuple allocations and field-access overhead. +//! +//! Establishes [`crate::invariants::InvariantLevel::PostSroa`]: +//! synthesized local tuple patterns agree with the tuple types they +//! decompose. +//! +//! ## Prerequisites +//! +//! The `tuple_compare_lower` pass must run before SROA. It rewrites +//! equality and inequality on non-empty tuples into element-wise scalar +//! comparisons, which eliminates whole-value uses that would otherwise +//! prevent decomposition. +//! +//! ## Decomposition +//! +//! For each entry-reachable callable, the pass: +//! - Identifies local bindings whose type is `Ty::Tuple(...)` or +//! `Ty::Udt(Res::Item(_))` (resolving to a multi-field UDT) and whose +//! every use is `ExprKind::Field`, `ExprKind::AssignField`, or +//! `ExprKind::Assign(Var, Tuple)` (whole-tuple reassignment with a +//! tuple-literal RHS). +//! - Decomposes those bindings in-place: `PatKind::Bind(t)` becomes +//! `PatKind::Tuple([Bind(t_0), Bind(t_1), ...])`, field accesses become +//! direct variable references, and whole-tuple assignments are split into +//! per-element assignments. +//! +//! 
The eligibility criterion is conservative: a binding is decomposed only
+//! when *all* of its uses are field-only accesses or decomposable tuple
+//! assignments. If the variable is ever passed as a whole value (e.g., as
+//! an argument, a return value, or a closure capture), it is left intact.
+//!
+//! ## Iterative fixed-point
+//!
+//! The pass runs iteratively to a fixed point. Each iteration peels one
+//! level of nesting: `Bind(t: (A, B))` → `Tuple([Bind(t_0: A), Bind(t_1: B)])`.
+//! When `A` is itself a tuple (e.g., `(Int, Int)`), the next iteration
+//! discovers `Bind(t_0: (Int, Int))` as a new candidate and decomposes it
+//! further. The loop terminates when no new candidates remain.
+//!
+//! ## Input patterns
+//!
+//! - `let t = (a, b, c);` where every later reference is `t::0`, `t::1`,
+//!   `t::2`, `t = (a', b', c')`, or `t::0 = a'`.
+//!
+//! ## Rewrites
+//!
+//! ```text
+//! // Before
+//! let t = (a, b, c);
+//! let x = t::1;
+//! t = (a', b', c');
+//!
+//! // After
+//! let (t_0, t_1, t_2) = (a, b, c);
+//! let x = t_1;
+//! (t_0, t_1, t_2) = (a', b', c');
+//! ```
+//!
+//! ## Notes
+//!
+//! - Synthesized expressions use `EMPTY_EXEC_RANGE`;
+//!   [`crate::exec_graph_rebuild`] rebuilds correct exec graphs at the end
+//!   of the pipeline.
+
+#[cfg(test)]
+mod tests;
+
+#[cfg(all(test, feature = "slow-proptest-tests"))]
+mod semantic_equivalence_tests;
+
+use crate::fir_builder::{
+    alloc_local_var_expr, decompose_binding, functored_specs, reachable_local_callables,
+    resolve_udt_element_types,
+};
+use crate::reachability::collect_reachable_from_entry;
+use crate::walk_utils::{collect_expr_ids_in_local_callables, collect_uses_in_block};
+use qsc_data_structures::span::Span;
+use qsc_fir::assigner::Assigner;
+use qsc_fir::fir::{
+    BlockId, CallableDecl, CallableImpl, Expr, ExprId, ExprKind, Field, FieldPath, ItemKind,
+    LocalItemId, LocalVarId, Package, PackageId, PackageLookup, PackageStore, PatId, PatKind, Res,
+    SpecDecl, SpecImpl, Stmt, StmtId, StmtKind,
+};
+use qsc_fir::ty::Ty;
+use rustc_hash::FxHashMap;
+use std::rc::Rc;
+
+use crate::EMPTY_EXEC_RANGE;
+
+/// Runs the SROA pass on the entry-reachable portion of a package.
+///
+/// For each local binding of tuple type where every use is a field access
+/// or field assignment, decomposes the binding into individual scalar
+/// variables and rewrites field accesses into direct variable references.
+pub fn sroa(store: &mut PackageStore, package_id: PackageId, assigner: &mut Assigner) {
+    let package = store.get(package_id);
+    // Without an entry expression nothing is reachable, so the pass is a no-op.
+    if package.entry.is_none() {
+        return;
+    }
+
+    // Iterate to a fixed point: each round peels one level of tuple nesting,
+    // and the scalar bindings it introduces may themselves be tuple-typed
+    // candidates discovered by the next round.
+    loop {
+        let reachable = collect_reachable_from_entry(store, package_id);
+        let package = store.get(package_id);
+
+        // Collect candidates across all reachable callables.
+        let mut all_candidates: Vec<SroaCandidate> = Vec::new();
+
+        for (item_id, decl) in reachable_local_callables(package, package_id, &reachable) {
+            collect_candidates_in_callable(store, package_id, item_id, decl, &mut all_candidates);
+        }
+
+        // No new candidates means the previous round reached the fixed point.
+        if all_candidates.is_empty() {
+            break;
+        }
+
+        // Apply decomposition.
+        let package = store.get_mut(package_id);
+        for candidate in &all_candidates {
+            decompose_candidate(package, assigner, candidate);
+        }
+    }
+}
+
+/// A candidate for SROA decomposition.
+struct SroaCandidate { + /// The `LocalVarId` bound by the original `PatKind::Bind`. + local_id: LocalVarId, + /// The `PatId` of the binding pattern. + pat_id: PatId, + /// Element types from the tuple. + elem_types: Vec, + /// The name of the original binding. + name: Rc, + /// The callable item that owns this local binding. + owner_item: LocalItemId, +} + +/// Scans a callable's body for SROA candidates. +fn collect_candidates_in_callable( + store: &PackageStore, + package_id: PackageId, + owner_item: LocalItemId, + decl: &CallableDecl, + candidates: &mut Vec, +) { + match &decl.implementation { + CallableImpl::Intrinsic => {} + CallableImpl::Spec(spec_impl) => { + collect_candidates_in_spec_impl(store, package_id, owner_item, spec_impl, candidates); + } + CallableImpl::SimulatableIntrinsic(spec) => { + collect_candidates_in_spec(store, package_id, owner_item, spec, candidates); + } + } +} + +/// Recurses into every specialization of a `SpecImpl` to collect SROA +/// candidates. +fn collect_candidates_in_spec_impl( + store: &PackageStore, + package_id: PackageId, + owner_item: LocalItemId, + spec_impl: &SpecImpl, + candidates: &mut Vec, +) { + collect_candidates_in_spec(store, package_id, owner_item, &spec_impl.body, candidates); + for spec in functored_specs(spec_impl) { + collect_candidates_in_spec(store, package_id, owner_item, spec, candidates); + } +} + +/// Collects SROA candidates within a single `SpecDecl` body by walking +/// tuple-typed bindings and checking every use for field-only or +/// decomposable-tuple-assignment eligibility. +fn collect_candidates_in_spec( + store: &PackageStore, + package_id: PackageId, + owner_item: LocalItemId, + spec: &SpecDecl, + candidates: &mut Vec, +) { + let package = store.get(package_id); + // Collect all local bindings with composite (tuple or UDT) type. + let bindings = find_tuple_bindings_in_block(store, package_id, spec.block); + + for binding in bindings { + // Verify ALL uses are field-only. 
+ if all_uses_are_field_access(package, spec.block, binding.local_id) { + candidates.push(SroaCandidate { + local_id: binding.local_id, + pat_id: binding.pat_id, + elem_types: binding.elem_types, + name: binding.name, + owner_item, + }); + } + } +} + +/// Information about a tuple-typed local binding. +struct TupleBinding { + local_id: LocalVarId, + pat_id: PatId, + elem_types: Vec, + name: Rc, +} + +/// Recursively walks a pattern to find `PatKind::Bind` nodes with tuple or +/// UDT types. This handles patterns produced by a previous SROA pass that +/// transformed `Bind(t)` into `Tuple([Bind(t_0), Bind(t_1), ...])` — the +/// inner `Bind(t_0)` would otherwise be invisible to the scanner. +fn find_binds_in_pat( + store: &PackageStore, + package_id: PackageId, + pat_id: PatId, + bindings: &mut Vec, +) { + let package = store.get(package_id); + let pat = package.get_pat(pat_id); + match &pat.kind { + PatKind::Bind(ident) => { + let elem_types = match &pat.ty { + Ty::Tuple(elems) if !elems.is_empty() => Some(elems.clone()), + Ty::Udt(Res::Item(item_id)) => resolve_udt_element_types(store, item_id), + _ => None, + }; + if let Some(elem_types) = elem_types { + bindings.push(TupleBinding { + local_id: ident.id, + pat_id, + elem_types, + name: ident.name.clone(), + }); + } + } + PatKind::Tuple(sub_pats) => { + for &sub_pat_id in sub_pats { + find_binds_in_pat(store, package_id, sub_pat_id, bindings); + } + } + PatKind::Discard => {} + } +} + +/// Finds all `StmtKind::Local(_, pat, _)` in a block where `pat` is +/// `PatKind::Bind(ident)` with `Ty::Tuple(elems)` or `Ty::Udt(Res::Item(_))` +/// resolving to a multi-field UDT, and the composite type is non-empty. 
+fn find_tuple_bindings_in_block(
+    store: &PackageStore,
+    package_id: PackageId,
+    block_id: BlockId,
+) -> Vec<TupleBinding> {
+    let mut bindings = Vec::new();
+    find_tuple_bindings_recursive(store, package_id, block_id, &mut bindings);
+    bindings
+}
+
+/// Walks a block (recursively through nested statements and expressions)
+/// collecting every candidate tuple-typed binding into `bindings`.
+fn find_tuple_bindings_recursive(
+    store: &PackageStore,
+    package_id: PackageId,
+    block_id: BlockId,
+    bindings: &mut Vec<TupleBinding>,
+) {
+    let package = store.get(package_id);
+    let block = package.get_block(block_id);
+    for &stmt_id in &block.stmts {
+        let stmt = package.get_stmt(stmt_id);
+        match &stmt.kind {
+            StmtKind::Local(_, pat_id, expr_id) => {
+                find_binds_in_pat(store, package_id, *pat_id, bindings);
+                // Recurse into nested blocks in the RHS expression.
+                find_tuple_bindings_in_expr_id(store, package_id, *expr_id, bindings);
+            }
+            StmtKind::Expr(e) | StmtKind::Semi(e) => {
+                find_tuple_bindings_in_expr_id(store, package_id, *e, bindings);
+            }
+            // Item statements introduce no local bindings.
+            StmtKind::Item(_) => {}
+        }
+    }
+}
+
+/// Descends into an expression subtree collecting candidate bindings from
+/// nested blocks, conditionals, while-loops, and match-like constructs.
+fn find_tuple_bindings_in_expr_id(
+    store: &PackageStore,
+    package_id: PackageId,
+    expr_id: ExprId,
+    bindings: &mut Vec<TupleBinding>,
+) {
+    let package = store.get(package_id);
+    let expr = package.get_expr(expr_id);
+    match &expr.kind {
+        ExprKind::Block(block_id) | ExprKind::While(_, block_id) => {
+            find_tuple_bindings_recursive(store, package_id, *block_id, bindings);
+        }
+        ExprKind::If(_, body, otherwise) => {
+            find_tuple_bindings_in_expr_id(store, package_id, *body, bindings);
+            if let Some(e) = otherwise {
+                find_tuple_bindings_in_expr_id(store, package_id, *e, bindings);
+            }
+        }
+        // Other expression kinds cannot contain statement-level bindings.
+        _ => {}
+    }
+}
+
+/// Returns `true` if every use of `local_id` in the block is a field access
+/// (`ExprKind::Field(Var(Local(id)), Path(_))`) or a field assignment
+/// (`ExprKind::AssignField(Var(Local(id)), _, _)`).
+///
+/// Returns `false` if `local_id` is used in any other context: passed as an
+/// argument, returned, captured by closure, assigned whole, etc.
+fn all_uses_are_field_access(package: &Package, block_id: BlockId, local_id: LocalVarId) -> bool {
+    let mut uses = Vec::new();
+    collect_uses_in_block(package, block_id, local_id, &mut uses);
+    // A binding with no recorded uses is vacuously eligible (`all` on an
+    // empty collection is true).
+    uses.iter().all(|u| *u)
+}
+
+/// Decomposes a single SROA candidate in-place.
+///
+/// # Before
+/// ```text
+/// let t : (A, B) = (a, b); // single tuple binding
+/// use(t.0); use(t.1); // only field accesses
+/// ```
+/// # After
+/// ```text
+/// let (t_0, t_1) : (A, B) = (a, b); // binding split to scalars
+/// use(t_0); use(t_1); // field accesses → direct vars
+/// ```
+///
+/// # Mutations
+/// - Rewrites the binding `Pat` from `Bind` to `Tuple` of per-element `Bind`s.
+/// - Allocates new `LocalVarId`, `PatId` nodes through `assigner`.
+/// - Delegates to [`rewrite_field_accesses`] and [`rewrite_assign_tuples`].
+fn decompose_candidate(package: &mut Package, assigner: &mut Assigner, candidate: &SroaCandidate) { + let new_locals = decompose_binding( + package, + assigner, + candidate.pat_id, + &candidate.name, + &candidate.elem_types, + ); + + // Rewrite all field accesses and assign-field expressions. + rewrite_field_accesses( + package, + assigner, + candidate.owner_item, + candidate.local_id, + &new_locals, + &candidate.elem_types, + ); + + // Split `Assign(Var(Local(old)), Tuple([e0, e1, ...]))` into per-element + // assignments. This must run AFTER field access rewriting so that any + // `Field(Var(Local(old)), Path([i]))` references in the RHS elements + // have already been rewritten to `Var(Local(new_i))`. + rewrite_assign_tuples( + package, + assigner, + candidate.owner_item, + candidate.local_id, + &new_locals, + &candidate.elem_types, + ); +} + +/// Rewrites all `ExprKind::Field(Var(Local(old)), Path([i, ...]))` and +/// `ExprKind::AssignField(Var(Local(old)), Path([i, ...]), value)` uses across +/// the entire package so they target the decomposed scalar or nested aggregate +/// for the first path segment. +/// +/// # Before +/// ```text +/// Field(Var(Local(old)), Path([1])) // tuple.1 +/// ``` +/// # After +/// ```text +/// Var(Local(old_1)) // direct scalar reference +/// ``` +/// +/// # Mutations +/// - Allocates replacement `Var` and `Field` `Expr` nodes through `assigner`. +/// - Redirects all parent references from old to new via +/// [`replace_expr_references`]. +fn rewrite_field_accesses( + package: &mut Package, + assigner: &mut Assigner, + owner_item: LocalItemId, + old_local: LocalVarId, + new_locals: &[LocalVarId], + elem_types: &[Ty], +) { + // Collect ExprIds only from the owning callable (locals cannot escape). 
+ let expr_ids = collect_expr_ids_in_local_callables(&*package, &[owner_item]); + + for expr_id in expr_ids { + rewrite_single_expr( + package, assigner, owner_item, expr_id, old_local, new_locals, elem_types, + ); + } +} + +/// Rewrites a single expression to replace references to an SROA-decomposed +/// tuple local with references to its scalar replacements. +/// +/// Handles two `ExprKind::Field` cases: +/// +/// - **Single-index path** (`t.i`): synthesize a fresh `Var(t_i)` expression +/// and redirect references to the old projection expression to it. +/// - **Nested path** (`t.i.j...`): synthesize a fresh `Var(t_i)` expression +/// and a fresh `Field(.., Path([j, ...]))` wrapper. Redirecting references +/// instead of mutating the original projection keeps shared expression nodes +/// stable for sibling projections created by earlier passes. +#[allow(clippy::too_many_lines)] +fn rewrite_single_expr( + package: &mut Package, + assigner: &mut Assigner, + owner_item: LocalItemId, + expr_id: ExprId, + old_local: LocalVarId, + new_locals: &[LocalVarId], + elem_types: &[Ty], +) { + let expr = package.exprs.get(expr_id).expect("expr should exist"); + match expr.kind.clone() { + ExprKind::Field(inner_id, Field::Path(path)) => { + let span = expr.span; + let expr_ty = expr.ty.clone(); + let inner = package + .exprs + .get(inner_id) + .expect("inner expr should exist"); + if let ExprKind::Var(Res::Local(var_id), _) = &inner.kind + && *var_id == old_local + && !path.indices.is_empty() + { + let idx = path.indices[0]; + if idx < new_locals.len() { + let new_local = new_locals[idx]; + if path.indices.len() == 1 { + let replacement_id = { + let ty = elem_types[idx].clone(); + alloc_local_var_expr(package, assigner, new_local, ty, span) + }; + replace_expr_references(package, owner_item, expr_id, replacement_id); + } else { + // Nested: t.i.j... 
-> Field(Var(t_i), Path([j, ...])) + let remaining: Vec = path.indices[1..].to_vec(); + + let new_inner_id = { + let ty = elem_types[idx].clone(); + alloc_local_var_expr(package, assigner, new_local, ty, span) + }; + let replacement_id = assigner.next_expr(); + package.exprs.insert( + replacement_id, + Expr { + id: replacement_id, + span, + ty: expr_ty, + kind: ExprKind::Field( + new_inner_id, + Field::Path(FieldPath { indices: remaining }), + ), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + replace_expr_references(package, owner_item, expr_id, replacement_id); + } + } + } + } + ExprKind::AssignField(record_id, Field::Path(path), value_id) => { + let span = expr.span; + let expr_ty = expr.ty.clone(); + let record = package + .exprs + .get(record_id) + .expect("record expr should exist"); + if let ExprKind::Var(Res::Local(var_id), _) = &record.kind + && *var_id == old_local + && !path.indices.is_empty() + { + let idx = path.indices[0]; + if idx < new_locals.len() { + let new_local = new_locals[idx]; + let new_record_id = { + let ty = elem_types[idx].clone(); + alloc_local_var_expr(package, assigner, new_local, ty, span) + }; + + let replacement_id = assigner.next_expr(); + let replacement_kind = if path.indices.len() == 1 { + ExprKind::Assign(new_record_id, value_id) + } else { + ExprKind::AssignField( + new_record_id, + Field::Path(FieldPath { + indices: path.indices[1..].to_vec(), + }), + value_id, + ) + }; + package.exprs.insert( + replacement_id, + Expr { + id: replacement_id, + span, + ty: expr_ty, + kind: replacement_kind, + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + replace_expr_references(package, owner_item, expr_id, replacement_id); + } + } + } + _ => {} + } +} + +/// Rewrites every reference to `old_expr_id` in the owner callable to point at +/// `new_expr_id`. +/// +/// Before, entry, statements, and parent expressions still point at the +/// aggregate expression that SROA wants to replace. 
After, every such edge +/// points at the scalarized replacement, allowing the old node to become dead. +fn replace_expr_references( + package: &mut Package, + owner_item: LocalItemId, + old_expr_id: ExprId, + new_expr_id: ExprId, +) { + if package.entry == Some(old_expr_id) { + package.entry = Some(new_expr_id); + } + + // Collect owner's block IDs and expr IDs with immutable borrow, then mutate. + let (block_ids, expr_ids) = { + let blocks = collect_all_block_ids_in_callable(&*package, owner_item); + let exprs = collect_expr_ids_in_local_callables(&*package, &[owner_item]); + (blocks, exprs) + }; + + for block_id in &block_ids { + let stmts: Vec = package.get_block(*block_id).stmts.clone(); + for stmt_id in stmts { + let stmt = package.stmts.get_mut(stmt_id).expect("stmt should exist"); + replace_expr_in_stmt(stmt, old_expr_id, new_expr_id); + } + } + + for expr_id in expr_ids { + let expr = package.exprs.get_mut(expr_id).expect("expr should exist"); + replace_expr_in_expr(expr, old_expr_id, new_expr_id); + } +} + +fn replace_expr_in_stmt(stmt: &mut Stmt, old_expr_id: ExprId, new_expr_id: ExprId) { + match &mut stmt.kind { + StmtKind::Expr(expr_id) | StmtKind::Semi(expr_id) | StmtKind::Local(_, _, expr_id) => { + replace_expr_id(expr_id, old_expr_id, new_expr_id); + } + StmtKind::Item(_) => {} + } +} + +fn replace_expr_in_expr(expr: &mut Expr, old_expr_id: ExprId, new_expr_id: ExprId) { + match &mut expr.kind { + ExprKind::Array(exprs) | ExprKind::ArrayLit(exprs) | ExprKind::Tuple(exprs) => { + for expr_id in exprs { + replace_expr_id(expr_id, old_expr_id, new_expr_id); + } + } + ExprKind::ArrayRepeat(a, b) + | ExprKind::Assign(a, b) + | ExprKind::AssignOp(_, a, b) + | ExprKind::BinOp(_, a, b) + | ExprKind::Call(a, b) + | ExprKind::Index(a, b) + | ExprKind::AssignField(a, _, b) + | ExprKind::UpdateField(a, _, b) => { + replace_expr_id(a, old_expr_id, new_expr_id); + replace_expr_id(b, old_expr_id, new_expr_id); + } + ExprKind::AssignIndex(a, b, c) | 
ExprKind::UpdateIndex(a, b, c) => { + replace_expr_id(a, old_expr_id, new_expr_id); + replace_expr_id(b, old_expr_id, new_expr_id); + replace_expr_id(c, old_expr_id, new_expr_id); + } + ExprKind::Fail(expr_id) + | ExprKind::Field(expr_id, _) + | ExprKind::Return(expr_id) + | ExprKind::UnOp(_, expr_id) => { + replace_expr_id(expr_id, old_expr_id, new_expr_id); + } + ExprKind::If(cond, body, otherwise) => { + replace_expr_id(cond, old_expr_id, new_expr_id); + replace_expr_id(body, old_expr_id, new_expr_id); + if let Some(expr_id) = otherwise { + replace_expr_id(expr_id, old_expr_id, new_expr_id); + } + } + ExprKind::Range(start, step, end) => { + for expr_id in [start, step, end].into_iter().flatten() { + replace_expr_id(expr_id, old_expr_id, new_expr_id); + } + } + ExprKind::Struct(_, copy, fields) => { + if let Some(expr_id) = copy { + replace_expr_id(expr_id, old_expr_id, new_expr_id); + } + for field in fields { + replace_expr_id(&mut field.value, old_expr_id, new_expr_id); + } + } + ExprKind::String(components) => { + for component in components { + if let qsc_fir::fir::StringComponent::Expr(expr_id) = component { + replace_expr_id(expr_id, old_expr_id, new_expr_id); + } + } + } + ExprKind::While(cond, _) => { + replace_expr_id(cond, old_expr_id, new_expr_id); + } + ExprKind::Block(_) + | ExprKind::Closure(_, _) + | ExprKind::Hole + | ExprKind::Lit(_) + | ExprKind::Var(_, _) => {} + } +} + +fn replace_expr_id(expr_id: &mut ExprId, old_expr_id: ExprId, new_expr_id: ExprId) { + if *expr_id == old_expr_id { + *expr_id = new_expr_id; + } +} + +/// Builds a mapping from `StmtId` → `BlockId` for the owner callable's blocks. 
+fn build_stmt_block_map_for_callable( + package: &Package, + item_id: LocalItemId, +) -> FxHashMap { + let mut map = FxHashMap::default(); + let block_ids = collect_all_block_ids_in_callable(package, item_id); + for block_id in block_ids { + let block = package.get_block(block_id); + for &stmt_id in &block.stmts { + map.insert(stmt_id, block_id); + } + } + map +} + +/// Collects all block IDs reachable from a callable's implementation. +fn collect_all_block_ids_in_callable(package: &Package, item_id: LocalItemId) -> Vec { + let Some(item) = package.items.get(item_id) else { + return Vec::new(); + }; + let ItemKind::Callable(decl) = &item.kind else { + return Vec::new(); + }; + let mut block_ids = Vec::new(); + // Include spec-level blocks. + match &decl.implementation { + CallableImpl::Intrinsic => {} + CallableImpl::Spec(spec_impl) => { + block_ids.push(spec_impl.body.block); + for spec in functored_specs(spec_impl) { + block_ids.push(spec.block); + } + } + CallableImpl::SimulatableIntrinsic(spec) => { + block_ids.push(spec.block); + } + } + // Include nested blocks found via expression walking. + crate::walk_utils::for_each_expr_in_callable_impl( + package, + &decl.implementation, + &mut |_, expr| match &expr.kind { + ExprKind::Block(bid) | ExprKind::While(_, bid) => { + block_ids.push(*bid); + } + _ => {} + }, + ); + block_ids +} + +/// Splits `Assign(Var(Local(old)), Tuple([e0, e1, ...]))` into per-element +/// assignments across the containing block. +/// +/// # Before +/// ```text +/// set old = (a, b); // single Semi(Assign(..)) statement +/// ``` +/// # After +/// ```text +/// set old_0 = a; // original stmt rewritten in-place +/// set old_1 = b; // new stmt inserted after +/// ``` +/// +/// # Mutations +/// - Rewrites the original `Assign` `ExprKind` in-place for element 0. +/// - Allocates new `Expr` and `Stmt` nodes for elements 1..n-1. +/// - Inserts new statements into the containing block after the original. 
+fn rewrite_assign_tuples( + package: &mut Package, + assigner: &mut Assigner, + owner_item: LocalItemId, + old_local: LocalVarId, + new_locals: &[LocalVarId], + elem_types: &[Ty], +) { + let stmt_block_map = build_stmt_block_map_for_callable(package, owner_item); + + // Collect (stmt_id, expr_id, elements) for all matching Assign-Tuple patterns. + let mut rewrites: Vec<(StmtId, ExprId, Vec)> = Vec::new(); + + for &stmt_id in stmt_block_map.keys() { + let stmt = package.stmts.get(stmt_id).expect("stmt should exist"); + let semi_expr_id = match &stmt.kind { + StmtKind::Semi(e) => *e, + _ => continue, + }; + let expr = package.exprs.get(semi_expr_id).expect("expr should exist"); + if let ExprKind::Assign(lhs_id, rhs_id) = &expr.kind { + let lhs = package.exprs.get(*lhs_id).expect("lhs should exist"); + if let ExprKind::Var(Res::Local(var_id), _) = &lhs.kind + && *var_id == old_local + { + let rhs = package.exprs.get(*rhs_id).expect("rhs should exist"); + if let ExprKind::Tuple(elements) = &rhs.kind { + rewrites.push((stmt_id, semi_expr_id, elements.clone())); + } + } + } + } + + for (stmt_id, assign_expr_id, elements) in rewrites { + let n = elements.len().min(new_locals.len()); + if n == 0 { + continue; + } + + // Rewrite the original Assign in-place to target the first element. + { + // Create a new Var expr for the first element's LHS. + let new_lhs_id = assigner.next_expr(); + let new_lhs = Expr { + id: new_lhs_id, + span: Span::default(), + ty: elem_types[0].clone(), + kind: ExprKind::Var(Res::Local(new_locals[0]), vec![]), + exec_graph_range: EMPTY_EXEC_RANGE, + }; + package.exprs.insert(new_lhs_id, new_lhs); + + let assign = package + .exprs + .get_mut(assign_expr_id) + .expect("assign expr exists"); + assign.kind = ExprKind::Assign(new_lhs_id, elements[0]); + assign.ty = elem_types[0].clone(); + } + + // For elements 1..n, create new Assign exprs and Semi stmts. 
+ let mut new_stmt_ids: Vec = Vec::with_capacity(n - 1); + for i in 1..n { + let lhs_id = assigner.next_expr(); + let lhs_expr = Expr { + id: lhs_id, + span: Span::default(), + ty: elem_types[i].clone(), + kind: ExprKind::Var(Res::Local(new_locals[i]), vec![]), + exec_graph_range: EMPTY_EXEC_RANGE, + }; + package.exprs.insert(lhs_id, lhs_expr); + + let assign_id = assigner.next_expr(); + let assign_expr = Expr { + id: assign_id, + span: Span::default(), + ty: elem_types[i].clone(), + kind: ExprKind::Assign(lhs_id, elements[i]), + exec_graph_range: EMPTY_EXEC_RANGE, + }; + package.exprs.insert(assign_id, assign_expr); + + let new_stmt_id = assigner.next_stmt(); + let new_stmt = Stmt { + id: new_stmt_id, + span: Span::default(), + kind: StmtKind::Semi(assign_id), + exec_graph_range: EMPTY_EXEC_RANGE, + }; + package.stmts.insert(new_stmt_id, new_stmt); + new_stmt_ids.push(new_stmt_id); + } + + // Insert the new stmts into the containing block after the original stmt. + if let Some(&block_id) = stmt_block_map.get(&stmt_id) { + let block = package + .blocks + .get_mut(block_id) + .expect("block should exist"); + if let Some(pos) = block.stmts.iter().position(|&s| s == stmt_id) { + for (offset, new_id) in new_stmt_ids.into_iter().enumerate() { + block.stmts.insert(pos + 1 + offset, new_id); + } + } + } + } +} diff --git a/source/compiler/qsc_fir_transforms/src/sroa/semantic_equivalence_tests.rs b/source/compiler/qsc_fir_transforms/src/sroa/semantic_equivalence_tests.rs new file mode 100644 index 0000000000..85800a7054 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/sroa/semantic_equivalence_tests.rs @@ -0,0 +1,126 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use indoc::formatdoc; +use indoc::indoc; +use proptest::prelude::*; + +#[test] +fn tuple_local_split_preserves_semantics() { + crate::test_utils::check_semantic_equivalence(indoc! 
{r#" + namespace Test { + @EntryPoint() + function Main() : Int { + let pair = (10, 20); + let (a, b) = pair; + a + b + } + } + "#}); +} + +#[test] +fn struct_field_access_split_preserves_semantics() { + crate::test_utils::check_semantic_equivalence(indoc! {r#" + namespace Test { + struct Point { X : Int, Y : Int } + + @EntryPoint() + function Main() : Int { + let p = new Point { X = 3, Y = 7 }; + p.X * p.Y + } + } + "#}); +} + +#[test] +fn mutable_tuple_update_split_preserves_semantics() { + crate::test_utils::check_semantic_equivalence(indoc! {r#" + namespace Test { + @EntryPoint() + function Main() : Int { + mutable pair = (1, 2); + let (a, b) = pair; + set pair = (a + 10, b + 20); + let (c, d) = pair; + c + d + } + } + "#}); +} + +fn sroa_tuple_local_pattern() -> impl Strategy { + (2..=5usize, 1..=3usize).prop_map(|(width, depth)| { + let type_defs = sroa_struct_defs(width, depth); + let initial_value = sroa_struct_value(width, depth, 0); + let first_access = sroa_field_path(0, depth); + let last_access = sroa_field_path(width - 1, depth); + + formatdoc! 
{r#" + namespace Test {{ + {type_defs} + + @EntryPoint() + function Main() : Int {{ + let tupleValue = {initial_value}; + tupleValue.{first_access} + tupleValue.{last_access} + }} + }} + "#} + }) +} + +fn sroa_struct_defs(width: usize, depth: usize) -> String { + (1..=depth) + .map(|level| { + let field_ty = if level == 1 { + "Int".to_string() + } else { + format!("TupleLevel{}", level - 1) + }; + let fields = (0..width) + .map(|field_index| format!("F{field_index} : {field_ty}")) + .collect::>() + .join(", "); + format!(" struct TupleLevel{level} {{ {fields} }}") + }) + .collect::>() + .join("\n") +} + +fn sroa_struct_value(width: usize, level: usize, offset: usize) -> String { + let assignments = (0..width) + .map(|field_index| { + let value = if level == 1 { + (offset + field_index).to_string() + } else { + let stride = width.pow( + u32::try_from(level - 1) + .expect("Depth should be small enough to avoid overflow"), + ); + sroa_struct_value(width, level - 1, offset + field_index * stride) + }; + format!("F{field_index} = {value}") + }) + .collect::>() + .join(", "); + + format!("new TupleLevel{level} {{ {assignments} }}") +} + +fn sroa_field_path(field_index: usize, depth: usize) -> String { + (0..depth) + .map(|_| format!("F{field_index}")) + .collect::>() + .join(".") +} + +proptest! { + #![proptest_config(ProptestConfig::with_cases(50))] + + #[test] + fn sroa_preserves_semantics(source in sroa_tuple_local_pattern()) { + crate::test_utils::check_semantic_equivalence(&source); + } +} diff --git a/source/compiler/qsc_fir_transforms/src/sroa/tests.rs b/source/compiler/qsc_fir_transforms/src/sroa/tests.rs new file mode 100644 index 0000000000..8a264b1f19 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/sroa/tests.rs @@ -0,0 +1,938 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +use super::*; +use crate::test_utils::{PipelineStage, compile_and_run_pipeline_to}; +use expect_test::{Expect, expect}; +use indoc::indoc; +use qsc_fir::assigner::Assigner; +use qsc_fir::fir::{ + BinOp, CallableImpl, ExprKind, ItemKind, Mutability, PackageLookup, PatKind, Res, StmtKind, +}; +use rustc_hash::FxHashMap; + +fn check(source: &str, expect: &Expect) { + let (mut store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::UdtErase); + let mut assigner = Assigner::from_package(store.get(pkg_id)); + sroa(&mut store, pkg_id, &mut assigner); + let result = extract_result(&store, pkg_id); + expect.assert_eq(&result); +} + +fn run_real_pipeline_to_sroa(source: &str) -> (PackageStore, PackageId) { + compile_and_run_pipeline_to(source, PipelineStage::Sroa) +} + +/// Compiles Q# source through the full FIR pipeline, then generates QIR via +/// partial evaluation and codegen. Uses Adaptive + `IntegerComputations` +/// capabilities so that Result-comparison programs can be lowered. 
+fn generate_qir(source: &str) -> String { + use qsc_codegen::qir::fir_to_qir; + use qsc_data_structures::target::TargetCapabilityFlags; + use qsc_partial_eval::ProgramEntry; + + let capabilities = TargetCapabilityFlags::Adaptive | TargetCapabilityFlags::IntegerComputations; + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Full); + let package = store.get(pkg_id); + let entry = ProgramEntry { + exec_graph: package.entry_exec_graph.clone(), + expr: ( + pkg_id, + package + .entry + .expect("package must have an entry expression"), + ) + .into(), + }; + let compute_properties = qsc_rca::Analyzer::init(&store, capabilities).analyze_all(); + fir_to_qir(&store, capabilities, &compute_properties, &entry).expect("QIR generation failed") +} + +fn extract_result(store: &PackageStore, pkg_id: PackageId) -> String { + let package = store.get(pkg_id); + let reachable = crate::reachability::collect_reachable_from_entry(store, pkg_id); + let mut entries: Vec = Vec::new(); + for store_id in &reachable { + if store_id.package != pkg_id { + continue; + } + let item = package.get_item(store_id.item); + if let ItemKind::Callable(decl) = &item.kind { + let mut lines = Vec::new(); + lines.push(format!( + "Callable {}: input={}", + decl.name.name, + format_pat(package, decl.input) + )); + if let CallableImpl::Spec(spec) = &decl.implementation { + let block = package.get_block(spec.body.block); + for &stmt_id in &block.stmts { + let stmt = package.get_stmt(stmt_id); + if let StmtKind::Local(mutability, pat_id, _) = &stmt.kind { + let mut_str = if matches!(mutability, Mutability::Mutable) { + "mutable " + } else { + "" + }; + lines.push(format!( + " local: {}{}", + mut_str, + format_pat(package, *pat_id) + )); + } + } + } + entries.push(lines.join("\n")); + } + } + entries.sort(); + entries.join("\n") +} + +fn format_pat(package: &qsc_fir::fir::Package, pat_id: PatId) -> String { + let pat = package.get_pat(pat_id); + match &pat.kind { + PatKind::Bind(ident) => 
format!("Bind({}: {})", ident.name, pat.ty), + PatKind::Tuple(sub_pats) => { + let subs: Vec = sub_pats.iter().map(|&id| format_pat(package, id)).collect(); + format!("Tuple({})", subs.join(", ")) + } + PatKind::Discard => format!("Discard({})", pat.ty), + } +} + +fn local_names(package: &qsc_fir::fir::Package) -> FxHashMap { + package + .pats + .values() + .filter_map(|pat| match &pat.kind { + PatKind::Bind(ident) => Some((ident.id, ident.name.to_string())), + PatKind::Tuple(_) | PatKind::Discard => None, + }) + .collect() +} + +fn local_name(names: &FxHashMap, local_id: LocalVarId) -> String { + names + .get(&local_id) + .cloned() + .unwrap_or_else(|| format!("<{local_id:?}>")) +} + +fn var_local_name( + package: &qsc_fir::fir::Package, + names: &FxHashMap, + expr_id: ExprId, +) -> Option { + let expr = package.get_expr(expr_id); + match &expr.kind { + ExprKind::Var(Res::Local(local_id), _) => Some(local_name(names, *local_id)), + _ => None, + } +} + +fn collect_eq_pairs_and_invalid_fields(source: &str) -> (Vec<(String, String)>, Vec) { + let (store, pkg_id) = run_real_pipeline_to_sroa(source); + let package = store.get(pkg_id); + let names = local_names(package); + let reachable = crate::reachability::collect_reachable_from_entry(&store, pkg_id); + + let mut eq_pairs = Vec::new(); + let mut invalid_fields = Vec::new(); + + for store_id in &reachable { + if store_id.package != pkg_id { + continue; + } + let item = package.get_item(store_id.item); + if let ItemKind::Callable(decl) = &item.kind { + crate::walk_utils::for_each_expr_in_callable_impl( + package, + &decl.implementation, + &mut |expr_id, expr| match &expr.kind { + ExprKind::BinOp(BinOp::Eq, lhs_id, rhs_id) => { + if let (Some(lhs_name), Some(rhs_name)) = ( + var_local_name(package, &names, *lhs_id), + var_local_name(package, &names, *rhs_id), + ) { + eq_pairs.push((lhs_name, rhs_name)); + } + } + ExprKind::Field(inner_id, _) => { + let inner = package.get_expr(*inner_id); + if !matches!(inner.ty, 
qsc_fir::ty::Ty::Tuple(_)) { + invalid_fields.push(format!( + "Expr {expr_id} targets non-tuple {inner_id} with type {}", + inner.ty + )); + } + } + _ => {} + }, + ); + } + } + + eq_pairs.sort(); + invalid_fields.sort(); + (eq_pairs, invalid_fields) +} + +const SHARED_VAR_TUPLE_COMPARE_SOURCE: &str = "operation Main() : Bool { + use (q0, q1) = (Qubit(), Qubit()); + let pair = (M(q0), M(q1)); + pair == pair + }"; + +#[test] +fn struct_fields_decompose() { + check( + "struct Pair { X : Int, Y : Int } + function Main() : Int { + let p = new Pair { X = 1, Y = 2 }; + p.X + p.Y + }", + &expect![[r#" + Callable Main: input=Tuple() + local: Tuple(Bind(p_0: Int), Bind(p_1: Int))"#]], + ); +} + +#[test] +fn mutable_struct_fields_decompose() { + check( + "struct Pair { X : Int, Y : Int } + function Main() : Int { + mutable p = new Pair { X = 1, Y = 2 }; + let x = p.X; + let y = p.Y; + x + y + }", + &expect![[r#" + Callable Main: input=Tuple() + local: mutable Tuple(Bind(p_0: Int), Bind(p_1: Int)) + local: Bind(x: Int) + local: Bind(y: Int)"#]], + ); +} + +#[test] +fn whole_value_use_skips_decomposition() { + check( + "struct Pair { X : Int, Y : Int } + function Foo(p : Pair) : Int { p.X } + function Main() : Int { + let p = new Pair { X = 1, Y = 2 }; + Foo(p) + }", + &expect![[r#" + Callable Foo: input=Bind(p: (Int, Int)) + Callable Main: input=Tuple() + local: Bind(p: (Int, Int))"#]], + ); +} + +#[test] +fn triple_struct_decomposes() { + check( + "struct Triple { A : Int, B : Int, C : Int } + function Main() : Int { + let t = new Triple { A = 1, B = 2, C = 3 }; + t.A + t.B + t.C + }", + &expect![[r#" + Callable Main: input=Tuple() + local: Tuple(Bind(t_0: Int), Bind(t_1: Int), Bind(t_2: Int))"#]], + ); +} + +#[test] +fn nested_struct_field_access() { + // After iterative SROA, both the outer and inner tuples decompose + // since the inner tuple's only use is a field access. 
+ check( + "struct Inner { X : Int, Y : Int } + struct Outer { P : Inner, Z : Int } + function Main() : Int { + let o = new Outer { P = new Inner { X = 1, Y = 2 }, Z = 3 }; + o.P.Y + }", + &expect![[r#" + Callable Main: input=Tuple() + local: Tuple(Tuple(Bind(o_0_0: Int), Bind(o_0_1: Int)), Bind(o_1: Int))"#]], + ); +} + +#[test] +fn tuple_used_in_both_field_and_whole_context() { + // When a struct is used both via field access AND as a whole value + // (e.g. returned), it must NOT be decomposed. + check( + "struct Pair { X : Int, Y : Int } + function Main() : Pair { + let p = new Pair { X = 1, Y = 2 }; + let x = p.X; + p + }", + &expect![[r#" + Callable Main: input=Tuple() + local: Bind(p: (Int, Int)) + local: Bind(x: Int)"#]], + ); +} + +#[test] +fn nested_tuple_depth_two() { + // Outer struct with two inner structs: iterative SROA decomposes + // both the outer and inner tuples since all uses are field-only. + check( + "struct Inner { A : Int, B : Int } + struct Outer { Left : Inner, Right : Inner } + function Main() : Int { + let o = new Outer { + Left = new Inner { A = 1, B = 2 }, + Right = new Inner { A = 3, B = 4 } + }; + o.Left.A + o.Right.B + }", + &expect![[r#" + Callable Main: input=Tuple() + local: Tuple(Tuple(Bind(o_0_0: Int), Bind(o_0_1: Int)), Tuple(Bind(o_1_0: Int), Bind(o_1_1: Int)))"#]], + ); +} + +#[test] +fn empty_tuple_local() { + // `let u = ();` — Unit is an empty tuple; should not panic, not decomposed. + check( + "function Main() : Unit { + let u = (); + }", + &expect![[r#" + Callable Main: input=Tuple() + local: Bind(u: Unit)"#]], + ); +} + +#[test] +fn single_field_struct_field_access() { + // Single-field struct: after UDT erasure the binding type is still + // a one-element tuple internally, so SROA decomposes it. 
+ check( + "struct Wrapper { Val : Int } + function Main() : Int { + let w = new Wrapper { Val = 42 }; + w.Val + }", + &expect![[r#" + Callable Main: input=Tuple() + local: Tuple(Bind(w_0: Int))"#]], + ); +} + +#[test] +fn mutable_tuple_partial_field_modification() { + // After UDT erasure, `set t w/= A <- 10` becomes a whole assignment + // `set t = (10, t.1, t.2)`. SROA now recognizes this Assign-Tuple + // pattern as decomposable and splits it into per-element assignments. + check( + "struct Triple { A : Int, B : Int, C : Int } + function Main() : Int { + mutable t = new Triple { A = 1, B = 2, C = 3 }; + t w/= A <- 10; + t.A + t.B + t.C + }", + &expect![[r#" + Callable Main: input=Tuple() + local: mutable Tuple(Bind(t_0: Int), Bind(t_1: Int), Bind(t_2: Int))"#]], + ); +} + +#[test] +fn tuple_passed_to_function_as_arg() { + // When a struct is passed as a whole argument to another function, + // it should NOT be decomposed (whole-value use). + check( + "struct Pair { X : Int, Y : Int } + function Sum(p : Pair) : Int { p.X + p.Y } + function Main() : Int { + let p = new Pair { X = 1, Y = 2 }; + Sum(p) + }", + &expect![[r#" + Callable Main: input=Tuple() + local: Bind(p: (Int, Int)) + Callable Sum: input=Bind(p: (Int, Int))"#]], + ); +} + +#[test] +fn sroa_candidate_in_while_loop_decomposes() { + // Struct binding inside a while loop body: SROA should handle + // control-flow nested bindings and decompose the nested local. 
+ let source = "struct Pair { A : Int, B : Int } + function Main() : Int { + mutable sum = 0; + mutable i = 0; + while i < 3 { + let p = new Pair { A = i, B = i + 1 }; + sum += p.A + p.B; + i += 1; + } + sum + }"; + check( + source, + &expect![[r#" + Callable Main: input=Tuple() + local: mutable Bind(sum: Int) + local: mutable Bind(i: Int)"#]], + ); + + let (mut store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::UdtErase); + let mut assigner = Assigner::from_package(store.get(pkg_id)); + sroa(&mut store, pkg_id, &mut assigner); + let local_patterns = collect_local_patterns_recursive(store.get(pkg_id)); + assert!( + local_patterns + .iter() + .any(|pat| pat == "Tuple(Bind(p_0: Int), Bind(p_1: Int))"), + "loop-local Pair binding should be decomposed, got {local_patterns:?}" + ); + assert!( + !local_patterns + .iter() + .any(|pat| pat == "Bind(p: (Int, Int))"), + "loop-local Pair binding should not remain whole, got {local_patterns:?}" + ); +} + +fn collect_local_patterns_recursive(package: &qsc_fir::fir::Package) -> Vec { + let mut patterns = Vec::new(); + for item in package.items.values() { + let ItemKind::Callable(decl) = &item.kind else { + continue; + }; + if let CallableImpl::Spec(spec) = &decl.implementation { + collect_local_patterns_in_block(package, spec.body.block, &mut patterns); + } + } + patterns.sort(); + patterns +} + +fn collect_local_patterns_in_block( + package: &qsc_fir::fir::Package, + block_id: qsc_fir::fir::BlockId, + patterns: &mut Vec, +) { + for &stmt_id in &package.get_block(block_id).stmts { + let stmt = package.get_stmt(stmt_id); + match &stmt.kind { + StmtKind::Expr(expr) | StmtKind::Semi(expr) => { + collect_local_patterns_in_expr(package, *expr, patterns); + } + StmtKind::Local(_, pat_id, expr) => { + patterns.push(format_pat(package, *pat_id)); + collect_local_patterns_in_expr(package, *expr, patterns); + } + StmtKind::Item(_) => {} + } + } +} + +fn collect_local_patterns_in_expr( + package: &qsc_fir::fir::Package, + 
expr_id: ExprId, + patterns: &mut Vec, +) { + match &package.get_expr(expr_id).kind { + ExprKind::Array(exprs) | ExprKind::ArrayLit(exprs) | ExprKind::Tuple(exprs) => { + for &expr in exprs { + collect_local_patterns_in_expr(package, expr, patterns); + } + } + ExprKind::ArrayRepeat(item, size) + | ExprKind::Assign(item, size) + | ExprKind::AssignOp(_, item, size) + | ExprKind::BinOp(_, item, size) + | ExprKind::Call(item, size) + | ExprKind::Index(item, size) + | ExprKind::AssignField(item, _, size) + | ExprKind::UpdateField(item, _, size) => { + collect_local_patterns_in_expr(package, *item, patterns); + collect_local_patterns_in_expr(package, *size, patterns); + } + ExprKind::AssignIndex(array, index, value) | ExprKind::UpdateIndex(array, index, value) => { + collect_local_patterns_in_expr(package, *array, patterns); + collect_local_patterns_in_expr(package, *index, patterns); + collect_local_patterns_in_expr(package, *value, patterns); + } + ExprKind::Block(block) => collect_local_patterns_in_block(package, *block, patterns), + ExprKind::Closure(_, _) | ExprKind::Hole | ExprKind::Lit(_) | ExprKind::Var(_, _) => {} + ExprKind::Fail(expr) + | ExprKind::Field(expr, _) + | ExprKind::Return(expr) + | ExprKind::UnOp(_, expr) => collect_local_patterns_in_expr(package, *expr, patterns), + ExprKind::If(cond, body, otherwise) => { + collect_local_patterns_in_expr(package, *cond, patterns); + collect_local_patterns_in_expr(package, *body, patterns); + if let Some(otherwise) = otherwise { + collect_local_patterns_in_expr(package, *otherwise, patterns); + } + } + ExprKind::Range(start, step, end) => { + for expr in [start, step, end].into_iter().flatten() { + collect_local_patterns_in_expr(package, *expr, patterns); + } + } + ExprKind::Struct(_, copy, fields) => { + if let Some(copy) = copy { + collect_local_patterns_in_expr(package, *copy, patterns); + } + for field in fields { + collect_local_patterns_in_expr(package, field.value, patterns); + } + } + 
ExprKind::String(components) => { + for component in components { + if let qsc_fir::fir::StringComponent::Expr(expr) = component { + collect_local_patterns_in_expr(package, *expr, patterns); + } + } + } + ExprKind::While(cond, block) => { + collect_local_patterns_in_expr(package, *cond, patterns); + collect_local_patterns_in_block(package, *block, patterns); + } + } +} + +#[test] +fn sroa_nested_struct_outer_decomposed_inner_field_access() { + // Inner/Outer struct with multi-level field access: o.I.X and o.I.Y. + // Iterative SROA decomposes both levels since all inner uses are + // field-only accesses. + check( + "struct Inner { X : Int, Y : Int } + struct Outer { I : Inner, Z : Bool } + function Main() : Int { + let o = new Outer { I = new Inner { X = 1, Y = 2 }, Z = true }; + o.I.X + o.I.Y + }", + &expect![[r#" + Callable Main: input=Tuple() + local: Tuple(Tuple(Bind(o_0_0: Int), Bind(o_0_1: Int)), Bind(o_1: Bool))"#]], + ); +} + +#[test] +fn nested_tuple_fully_flattened() { + // `((Int, Int), Bool)` with all field-only uses decomposes to three + // scalar bindings via iterative SROA. + check( + "struct Inner { A : Int, B : Int } + struct Outer { I : Inner, Z : Bool } + function Main() : Int { + let o = new Outer { I = new Inner { A = 10, B = 20 }, Z = false }; + o.I.A + o.I.B + }", + &expect![[r#" + Callable Main: input=Tuple() + local: Tuple(Tuple(Bind(o_0_0: Int), Bind(o_0_1: Int)), Bind(o_1: Bool))"#]], + ); +} + +#[test] +fn mutable_tuple_literal_reassignment_decomposes() { + // `set x = (3, 4)` with a tuple literal RHS is recognized as + // decomposable, so `x` is decomposed into `x_0`, `x_1`. 
+ check( + "struct Pair { A : Int, B : Int } + function Main() : Int { + mutable x = new Pair { A = 1, B = 2 }; + x = new Pair { A = 3, B = 4 }; + x.A + x.B + }", + &expect![[r#" + Callable Main: input=Tuple() + local: mutable Tuple(Bind(x_0: Int), Bind(x_1: Int))"#]], + ); +} + +#[test] +fn mutable_tuple_var_reassignment_no_decompose() { + // `set x = other` is NOT a tuple-literal RHS, so `x` is NOT decomposed. + check( + "struct Pair { A : Int, B : Int } + function Main() : Int { + let other = new Pair { A = 5, B = 6 }; + mutable x = new Pair { A = 1, B = 2 }; + x = other; + x.A + }", + &expect![[r#" + Callable Main: input=Tuple() + local: Bind(other: (Int, Int)) + local: mutable Bind(x: (Int, Int))"#]], + ); +} + +#[test] +fn sroa_tuple_compare() { + // Verify that tuple comparison with Result values is lowered by + // tuple_compare_lower, then SROA can decompose the tuple bindings, + // and the full pipeline produces valid QIR. + let qir = generate_qir( + "operation Main() : Bool { + use (q0, q1) = (Qubit(), Qubit()); + let (r0, r1) = (M(q0), M(q1)); + (r0, r1) == (Zero, Zero) + }", + ); + assert!( + !qir.is_empty(), + "QIR generation should succeed for tuple comparison after SROA" + ); +} + +#[test] +fn sroa_tuple_compare_shared_var_rewrites_all_eq_operands_after_pipeline_sroa() { + let (eq_pairs, invalid_fields) = + collect_eq_pairs_and_invalid_fields(SHARED_VAR_TUPLE_COMPARE_SOURCE); + + assert!( + invalid_fields.is_empty(), + "post-SROA should not leave field accesses on non-tuples:\n{}", + invalid_fields.join("\n") + ); + assert_eq!( + eq_pairs, + vec![ + ("pair_0".to_string(), "pair_0".to_string()), + ("pair_1".to_string(), "pair_1".to_string()), + ] + ); +} + +#[test] +fn multi_index_assign_field_decomposes_iteratively() { + let source = indoc! 
{" + namespace Test { + newtype Foo = (a: Int, (b: Double, c: Bool)); + @EntryPoint() + function Main() : Unit { + mutable f = Foo(1, (2.0, true)); + f w/= b <- 3.14; + } + } + "}; + let (mut store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::TupleCompLower); + let mut assigner = Assigner::from_package(store.get(pkg_id)); + sroa(&mut store, pkg_id, &mut assigner); + + let package = store.get(pkg_id); + let names = local_names(package); + let reachable = crate::reachability::collect_reachable_from_entry(&store, pkg_id); + let mut stale_uses = Vec::new(); + let mut assignments = Vec::new(); + + for store_id in &reachable { + if store_id.package != pkg_id { + continue; + } + let item = package.get_item(store_id.item); + let ItemKind::Callable(decl) = &item.kind else { + continue; + }; + crate::walk_utils::for_each_expr_in_callable_impl( + package, + &decl.implementation, + &mut |_expr_id, expr| match &expr.kind { + ExprKind::Assign(lhs_id, _) => { + if let Some(name) = var_local_name(package, &names, *lhs_id) { + assignments.push(name); + } + } + ExprKind::AssignField(record_id, Field::Path(path), _) => { + if let Some(name) = var_local_name(package, &names, *record_id) { + stale_uses.push(format!("{name}::{:?}", path.indices)); + } + } + _ => {} + }, + ); + } + + assignments.sort(); + stale_uses.sort(); + assert_eq!( + assignments, + vec!["f_0".to_string(), "f_1_0".to_string(), "f_1_1".to_string(),] + ); + assert!( + stale_uses.is_empty(), + "nested AssignField uses should be fully rewritten after iterative SROA: {stale_uses:?}" + ); +} + +#[test] +fn sroa_tuple_compare_shared_var_generates_qir() { + let qir = generate_qir(SHARED_VAR_TUPLE_COMPARE_SOURCE); + assert!( + !qir.is_empty(), + "QIR generation should succeed for tuple comparisons on a shared tuple local" + ); +} + +#[test] +fn higher_order_tuple_field_projection_still_decomposes() { + // A struct local whose only uses are field projections should still + // decompose even when those 
projections feed a higher-order call that + // defunctionalization specializes. + check( + "struct Pair { X : Int, Y : Int } + function Apply(f : (Int, Int) -> Int, x : Int, y : Int) : Int { f(x, y) } + function Main() : Int { + let p = new Pair { X = 1, Y = 2 }; + Apply((a, b) -> a + b, p.X, p.Y) + }", + &expect![[r#" + Callable : input=Tuple(Tuple(Bind(a: Int), Bind(b: Int))) + Callable Apply{closure}: input=Tuple(Bind(x: Int), Bind(y: Int)) + Callable Main: input=Tuple() + local: Tuple(Bind(p_0: Int), Bind(p_1: Int))"#]], + ); +} + +#[test] +fn nested_tuple_depth_three_fully_flattened() { + // Depth-3 nested tuple with all field-only access: iterative SROA + // should flatten all levels. + check( + "struct Inner { X : Int, Y : Int } + struct Mid { I : Inner, Z : Int } + struct Deep { M : Mid, W : Int } + function Main() : Int { + let d = new Deep { + M = new Mid { I = new Inner { X = 1, Y = 2 }, Z = 3 }, + W = 4 + }; + d.M.I.X + d.M.I.Y + d.M.Z + d.W + }", + &expect![[r#" + Callable Main: input=Tuple() + local: Tuple(Tuple(Tuple(Bind(d_0_0_0: Int), Bind(d_0_0_1: Int)), Bind(d_0_1: Int)), Bind(d_1: Int))"#]], + ); +} + +#[test] +fn struct_fields_decompose_in_adj_and_ctl_specs() { + let source = "struct Pair { X : Double, Y : Double } + operation Foo(q : Qubit) : Unit is Adj + Ctl { + let p = new Pair { X = 1.0, Y = 2.0 }; + Rx(p.X, q); + Ry(p.Y, q); + } + operation Main() : Unit { + use q = Qubit(); + use ctrl = Qubit(); + Foo(q); + Adjoint Foo(q); + Controlled Foo([ctrl], q); + }"; + let (mut store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::UdtErase); + let mut assigner = Assigner::from_package(store.get(pkg_id)); + sroa(&mut store, pkg_id, &mut assigner); + let result = extract_result_all_specs(&store, pkg_id); + expect![[r#" + Callable Foo: input=Bind(q: Qubit) + body: Tuple(Bind(p_0: Double), Bind(p_1: Double)) + adj: Tuple(Bind(p_0: Double), Bind(p_1: Double)) + ctl: Tuple(Bind(p_0: Double), Bind(p_1: Double)) + ctl_adj: Tuple(Bind(p_0: 
Double), Bind(p_1: Double)) + Callable Main: input=Tuple() + body: Bind(q: Qubit) + body: Bind(ctrl: Qubit)"#]] + .assert_eq(&result); +} + +/// Like [`extract_result`] but labels locals by specialization kind, so tests +/// can verify SROA decomposition in non-body specializations. +fn extract_result_all_specs(store: &PackageStore, pkg_id: PackageId) -> String { + let package = store.get(pkg_id); + let reachable = crate::reachability::collect_reachable_from_entry(store, pkg_id); + let mut entries: Vec = Vec::new(); + for store_id in &reachable { + if store_id.package != pkg_id { + continue; + } + let item = package.get_item(store_id.item); + if let ItemKind::Callable(decl) = &item.kind { + let mut lines = Vec::new(); + lines.push(format!( + "Callable {}: input={}", + decl.name.name, + format_pat(package, decl.input) + )); + if let CallableImpl::Spec(spec_impl) = &decl.implementation { + push_spec_locals(package, "body", &spec_impl.body, &mut lines); + if let Some(adj) = &spec_impl.adj { + push_spec_locals(package, "adj", adj, &mut lines); + } + if let Some(ctl) = &spec_impl.ctl { + push_spec_locals(package, "ctl", ctl, &mut lines); + } + if let Some(ctl_adj) = &spec_impl.ctl_adj { + push_spec_locals(package, "ctl_adj", ctl_adj, &mut lines); + } + } + entries.push(lines.join("\n")); + } + } + entries.sort(); + entries.join("\n") +} + +fn push_spec_locals( + package: &qsc_fir::fir::Package, + label: &str, + spec: &qsc_fir::fir::SpecDecl, + lines: &mut Vec, +) { + let block = package.get_block(spec.block); + for &stmt_id in &block.stmts { + let stmt = package.get_stmt(stmt_id); + if let StmtKind::Local(mutability, pat_id, _) = &stmt.kind { + let mut_str = if matches!(mutability, Mutability::Mutable) { + "mutable " + } else { + "" + }; + lines.push(format!( + " {label}: {mut_str}{}", + format_pat(package, *pat_id) + )); + } + } +} + +#[test] +fn sroa_is_idempotent() { + let source = "struct Pair { X : Int, Y : Int } + function Main() : Int { + let p = new Pair { X = 
1, Y = 2 }; + p.X + p.Y + }"; + let (mut store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Sroa); + let first = crate::pretty::write_package_qsharp(&store, pkg_id); + let mut assigner = Assigner::from_package(store.get(pkg_id)); + sroa(&mut store, pkg_id, &mut assigner); + let second = crate::pretty::write_package_qsharp(&store, pkg_id); + assert_eq!(first, second, "sroa should be idempotent"); +} + +fn render_before_after_sroa(source: &str) -> (String, String) { + let (mut store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::TupleCompLower); + let before = crate::pretty::write_package_qsharp(&store, pkg_id); + let mut assigner = Assigner::from_package(store.get(pkg_id)); + sroa(&mut store, pkg_id, &mut assigner); + let after = crate::pretty::write_package_qsharp(&store, pkg_id); + (before, after) +} + +fn check_before_after_sroa(source: &str, expect: &Expect) { + let (before, after) = render_before_after_sroa(source); + expect.assert_eq(&format!("BEFORE:\n{before}\nAFTER:\n{after}")); +} + +#[test] +fn before_after_struct_field_decomposition() { + check_before_after_sroa( + "struct Pair { X : Int, Y : Int } + function Main() : Int { + let p = new Pair { X = 1, Y = 2 }; + p.X + p.Y + }", + &expect![[r#" + BEFORE: + // namespace test + newtype Pair = (Int, Int); + function Main() : Int { + body { + let p : (Int, Int) = (1, 2); + p::Item < 0 > + p::Item < 1 > + } + } + // entry + Main() + + AFTER: + // namespace test + newtype Pair = (Int, Int); + function Main() : Int { + body { + let (p_0 : Int, p_1 : Int) = (1, 2); + p_0 + p_1 + } + } + // entry + Main() + "#]], // snapshot populated by UPDATE_EXPECT=1 + ); +} + +#[test] +fn pretty_print_after_sroa_is_non_empty() { + let source = indoc! 
{r#" + namespace Test { + @EntryPoint() + function Main() : Int { + let pair = (5, 6); + let (a, b) = pair; + a + b + } + } + "#}; + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Sroa); + let rendered = crate::pretty::write_package_qsharp(&store, pkg_id); + // After SROA the rendered Q# uses split tuple bindings and `body { ... }` + // spec syntax. Verify the render produces non-empty output. + assert!( + !rendered.is_empty(), + "pretty-printed Q# after SROA should not be empty" + ); +} + +#[test] +fn unreachable_callable_tuple_local_behavior() { + // Reachable Foo has a tuple local, Dead also has one. + // Document whether Dead's tuple local is scalarized. + // The `check` helper only extracts reachable callables via + // `collect_reachable_from_entry`, so this captures reachable-only output. + check( + indoc! {" + namespace Test { + @EntryPoint() + operation Main() : Int { + Foo() + } + operation Foo() : Int { + let t = (1, 2); + let (a, b) = t; + a + b + } + operation Dead() : Int { + let t = (3, 4); + let (a, b) = t; + a * b + } + } + "}, + &expect![[r#" + Callable Foo: input=Tuple() + local: Bind(t: (Int, Int)) + local: Tuple(Bind(a: Int), Bind(b: Int)) + Callable Main: input=Tuple()"#]], + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/test_utils.rs b/source/compiler/qsc_fir_transforms/src/test_utils.rs new file mode 100644 index 0000000000..9129828fa7 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/test_utils.rs @@ -0,0 +1,688 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Shared test helpers for the `qsc_fir_transforms` crate. +//! +//! Provides compilation and snapshot utilities used across transform test +//! modules. Gated behind `#[cfg(any(test, feature = "testutil"))]`. 
+ +use qsc_data_structures::{ + language_features::LanguageFeatures, source::SourceMap, target::TargetCapabilityFlags, +}; +use qsc_fir::fir::{ + self, CallableImpl, ExprId, ExprKind, ItemKind, LocalVarId, Package, PackageLookup, PatKind, + Res, SpecDecl, StmtId, StmtKind, +}; +use qsc_frontend::compile::{self as frontend_compile, PackageStore as HirPackageStore}; +use qsc_hir::hir::PackageId; +use qsc_passes::{PackageType, lower_hir_to_fir, run_core_passes, run_default_passes}; + +#[cfg(test)] +use qsc_lowerer::map_hir_package_to_fir; + +pub(crate) use crate::PipelineStage; + +fn format_errors(errors: &[T]) -> String { + errors + .iter() + .map(ToString::to_string) + .collect::>() + .join("\n") +} + +pub(crate) fn assert_no_compile_errors(context: &str, errors: &[frontend_compile::Error]) { + let error_messages = errors + .iter() + .map(|error| format!("{error:?}")) + .collect::>() + .join("\n"); + assert!( + errors.is_empty(), + "{context} has Q# compilation errors:\n{error_messages}" + ); +} + +pub fn assert_no_pipeline_errors(context: &str, errors: &[crate::PipelineError]) { + let error_messages = format_errors(errors); + assert!( + errors.is_empty(), + "{context} produced FIR transform pipeline errors:\n{error_messages}" + ); +} + +/// Sets up an HIR package store containing core + std libraries with default +/// passes applied, using the given target capabilities. 
+#[must_use] +pub fn package_store_with_stdlib(capabilities: TargetCapabilityFlags) -> HirPackageStore { + let mut core_unit = frontend_compile::core(); + assert_no_compile_errors("core library", &core_unit.errors); + let core_errors = run_core_passes(&mut core_unit); + assert!( + core_errors.is_empty(), + "core library has compilation errors" + ); + let mut store = HirPackageStore::new(core_unit); + + let mut std_unit = frontend_compile::std(&store, capabilities); + assert_no_compile_errors("std library", &std_unit.errors); + let std_errors = run_default_passes(store.core(), &mut std_unit, PackageType::Lib); + assert!(std_errors.is_empty(), "std library has compilation errors"); + store.insert(std_unit); + + store +} + +/// Convenience wrapper around [`package_store_with_stdlib`] that passes +/// [`TargetCapabilityFlags::empty()`]. +#[must_use] +pub fn package_store_with_stdlib_default() -> HirPackageStore { + package_store_with_stdlib(TargetCapabilityFlags::empty()) +} + +/// Compiles Q# source through core+std → HIR passes → FIR lowering. +/// +/// Returns a FIR store with no transforms applied. Uses default (empty) +/// target capabilities. +#[must_use] +pub fn compile_to_fir(source: &str) -> (fir::PackageStore, fir::PackageId) { + compile_to_fir_with_capabilities(source, TargetCapabilityFlags::empty()) +} + +/// Compiles Q# source through core+std → HIR passes → FIR lowering using the +/// given target capabilities. +/// +/// Returns a FIR store with no transforms applied. 
+#[must_use] +pub fn compile_to_fir_with_capabilities( + source: &str, + capabilities: TargetCapabilityFlags, +) -> (fir::PackageStore, fir::PackageId) { + let mut store = package_store_with_stdlib(capabilities); + let std_id = PackageId::CORE.successor(); + let sources = SourceMap::new(vec![("test.qs".into(), source.into())], None); + let mut unit = frontend_compile::compile( + &store, + &[(PackageId::CORE, None), (std_id, None)], + sources, + capabilities, + LanguageFeatures::default(), + ); + assert_no_compile_errors("user code", &unit.errors); + let pass_errors = run_default_passes(store.core(), &mut unit, PackageType::Exe); + assert!(pass_errors.is_empty(), "user code has compilation errors"); + let hir_package_id = store.insert(unit); + let (fir_store, fir_pkg_id, _) = lower_hir_to_fir(&store, hir_package_id); + (fir_store, fir_pkg_id) +} + +/// Compiles Q# source through core+std → HIR passes → FIR lowering → +/// monomorphization. +/// +/// Returns a monomorphized FIR store ready for defunctionalization or later +/// pipeline stages. Uses default (empty) target capabilities. +#[must_use] +pub fn compile_to_monomorphized_fir(source: &str) -> (fir::PackageStore, fir::PackageId) { + compile_to_monomorphized_fir_with_capabilities(source, TargetCapabilityFlags::empty()) +} + +/// Compiles Q# source through core+std → HIR passes → FIR lowering → +/// monomorphization using the given target capabilities. +/// +/// Returns a monomorphized FIR store ready for defunctionalization or later +/// pipeline stages. 
+#[must_use] +pub fn compile_to_monomorphized_fir_with_capabilities( + source: &str, + capabilities: TargetCapabilityFlags, +) -> (fir::PackageStore, fir::PackageId) { + let (mut store, pkg_id) = compile_to_fir_with_capabilities(source, capabilities); + let mut assigner = qsc_fir::assigner::Assigner::from_package(store.get(pkg_id)); + crate::monomorphize::monomorphize(&mut store, pkg_id, &mut assigner); + (store, pkg_id) +} + +/// Compiles Q# source through core+std → HIR passes → FIR lowering using an +/// explicit executable entry expression. +/// +/// Returns a FIR store with no transforms applied. +#[must_use] +pub fn compile_to_fir_with_entry(source: &str, entry: &str) -> (fir::PackageStore, fir::PackageId) { + let mut store = package_store_with_stdlib(TargetCapabilityFlags::empty()); + let std_id = PackageId::CORE.successor(); + let sources = SourceMap::new(vec![("test.qs".into(), source.into())], Some(entry.into())); + let mut unit = frontend_compile::compile( + &store, + &[(PackageId::CORE, None), (std_id, None)], + sources, + TargetCapabilityFlags::empty(), + LanguageFeatures::default(), + ); + assert_no_compile_errors("user code", &unit.errors); + let pass_errors = run_default_passes(store.core(), &mut unit, PackageType::Exe); + assert!(pass_errors.is_empty(), "user code has compilation errors"); + let hir_package_id = store.insert(unit); + let (fir_store, fir_pkg_id, _) = lower_hir_to_fir(&store, hir_package_id); + (fir_store, fir_pkg_id) +} + +/// Compiles Q# source and runs the FIR optimization pipeline up to the given +/// stage. +/// +/// # Panics +/// +/// Panics if compilation fails, or if the requested stage reaches +/// defunctionalization and the shared pipeline runner returns any errors. 
+#[allow(dead_code)] +pub(crate) fn compile_and_run_pipeline_to_with_errors( + source: &str, + stage: PipelineStage, +) -> (fir::PackageStore, fir::PackageId, Vec) { + let (mut store, pkg_id) = compile_to_fir(source); + let errors = crate::run_pipeline_to(&mut store, pkg_id, stage, &[]); + (store, pkg_id, errors) +} + +/// Compiles Q# source and runs the FIR optimization pipeline up to the given +/// stage, asserting that defunctionalization diagnostics stay empty once the +/// schedule reaches or passes that stage. +#[allow(dead_code)] +pub(crate) fn compile_and_run_pipeline_to( + source: &str, + stage: PipelineStage, +) -> (fir::PackageStore, fir::PackageId) { + let (store, pkg_id, errors) = compile_and_run_pipeline_to_with_errors(source, stage); + if matches!( + stage, + PipelineStage::Defunc + | PipelineStage::UdtErase + | PipelineStage::TupleCompLower + | PipelineStage::Sroa + | PipelineStage::ArgPromote + | PipelineStage::Gc + | PipelineStage::ItemDce + | PipelineStage::ExecGraphRebuild + | PipelineStage::Full + ) { + assert_no_pipeline_errors("compile_and_run_pipeline_to", &errors); + } + + (store, pkg_id) +} + +#[allow(dead_code)] +fn local_name(package: &Package, local_id: LocalVarId) -> Option<&str> { + package.pats.values().find_map(|pat| match &pat.kind { + PatKind::Bind(ident) if ident.id == local_id => Some(ident.name.as_ref()), + PatKind::Bind(_) | PatKind::Tuple(_) | PatKind::Discard => None, + }) +} + +#[allow(dead_code)] +fn callable_ref_short(package: &Package, pkg_id: fir::PackageId, expr_id: ExprId) -> String { + let expr = package.get_expr(expr_id); + match &expr.kind { + ExprKind::Var(Res::Item(item_id), _) if item_id.package == pkg_id => { + match &package.get_item(item_id.item).kind { + ItemKind::Callable(decl) => decl.name.name.to_string(), + _ => format!("Item({item_id})"), + } + } + ExprKind::Var(Res::Item(item_id), _) => format!("Item({item_id})"), + ExprKind::Var(Res::Local(local_id), _) => match local_name(package, *local_id) { + 
Some(name) => format!("Local({name})"), + None => format!("Local({local_id})"), + }, + ExprKind::UnOp(op, inner) => { + format!("{op}({})", callable_ref_short(package, pkg_id, *inner)) + } + _ => expr_kind_short(package, expr_id), + } +} + +#[allow(dead_code)] +fn expr_detail_short(package: &Package, pkg_id: fir::PackageId, expr_id: ExprId) -> String { + let expr = package.get_expr(expr_id); + match &expr.kind { + ExprKind::Call(callee, args) => { + let args_expr = package.get_expr(*args); + format!( + "Call({}, arg_ty={})", + callable_ref_short(package, pkg_id, *callee), + args_expr.ty + ) + } + _ => expr_kind_short(package, expr_id), + } +} + +#[allow(dead_code)] +fn push_spec_decl_summary( + package: &Package, + pkg_id: fir::PackageId, + label: &str, + spec: &SpecDecl, + lines: &mut Vec, +) { + let block = package.get_block(spec.block); + lines.push(format!(" {label}: block_ty={}", block.ty)); + for (index, stmt_id) in block.stmts.iter().enumerate() { + let stmt = package.get_stmt(*stmt_id); + let line = match &stmt.kind { + StmtKind::Expr(expr_id) => { + let expr = package.get_expr(*expr_id); + format!( + " [{index}] Expr ty={} {}", + expr.ty, + expr_detail_short(package, pkg_id, *expr_id) + ) + } + StmtKind::Semi(expr_id) => { + let expr = package.get_expr(*expr_id); + format!( + " [{index}] Semi ty={} {}", + expr.ty, + expr_detail_short(package, pkg_id, *expr_id) + ) + } + StmtKind::Local(_, pat_id, expr_id) => { + let pat = package.get_pat(*pat_id); + let expr = package.get_expr(*expr_id); + format!( + " [{index}] Local pat_ty={} init_ty={} {}", + pat.ty, + expr.ty, + expr_detail_short(package, pkg_id, *expr_id) + ) + } + StmtKind::Item(local_item_id) => format!(" [{index}] Item {local_item_id}"), + }; + lines.push(line); + } +} + +/// Extracts a deterministic summary of reachable callable signatures and body +/// shapes for the given package. 
+/// +/// Entries are sorted alphabetically before being joined so `expect_test` +/// snapshots remain stable across runs regardless of the iteration order of +/// the underlying reachable-set container. +#[allow(dead_code)] +pub(crate) fn extract_reachable_callable_details( + store: &fir::PackageStore, + pkg_id: fir::PackageId, +) -> String { + let package = store.get(pkg_id); + let reachable = crate::reachability::collect_reachable_from_entry(store, pkg_id); + + let mut entries = Vec::new(); + for store_id in &reachable { + if store_id.package != pkg_id { + continue; + } + let item = package.get_item(store_id.item); + if let ItemKind::Callable(decl) = &item.kind { + let pat = package.get_pat(decl.input); + let mut lines = vec![format!( + "callable {}: input_ty={}, output_ty={}", + decl.name.name, pat.ty, decl.output + )]; + + match &decl.implementation { + CallableImpl::Intrinsic => lines.push(" intrinsic".to_string()), + CallableImpl::SimulatableIntrinsic(spec) => { + push_spec_decl_summary(package, pkg_id, "simulatable", spec, &mut lines); + } + CallableImpl::Spec(spec_impl) => { + push_spec_decl_summary(package, pkg_id, "body", &spec_impl.body, &mut lines); + for (label, spec) in [ + ("adj", spec_impl.adj.as_ref()), + ("ctl", spec_impl.ctl.as_ref()), + ("ctl_adj", spec_impl.ctl_adj.as_ref()), + ] { + if let Some(spec) = spec { + push_spec_decl_summary(package, pkg_id, label, spec, &mut lines); + } + } + } + } + + entries.push(lines.join("\n")); + } + } + entries.sort(); + entries.join("\n") +} + +/// Asserts that the named callable body ends in an expression whose type +/// matches the enclosing block type. 
+pub fn assert_callable_body_terminal_expr_matches_block_type( + store: &fir::PackageStore, + pkg_id: fir::PackageId, + callable_name: &str, +) { + let package = store.get(pkg_id); + let item = package + .items + .values() + .find(|item| match &item.kind { + ItemKind::Callable(decl) => decl.name.name.as_ref() == callable_name, + _ => false, + }) + .expect("callable should exist"); + + let ItemKind::Callable(decl) = &item.kind else { + panic!("item should be callable"); + }; + let spec = match &decl.implementation { + CallableImpl::Spec(spec_impl) => &spec_impl.body, + CallableImpl::SimulatableIntrinsic(spec) => spec, + CallableImpl::Intrinsic => panic!("callable '{callable_name}' should have a body"), + }; + + let block = package.get_block(spec.block); + let last_stmt_id = *block + .stmts + .last() + .expect("callable body should not be empty"); + let last_stmt = package.get_stmt(last_stmt_id); + let StmtKind::Expr(expr_id) = last_stmt.kind else { + panic!( + "callable '{callable_name}' should end in an Expr stmt, got {:?}", + last_stmt.kind + ); + }; + let expr = package.get_expr(expr_id); + assert_eq!( + expr.ty, block.ty, + "callable '{callable_name}' trailing expr type should match block type" + ); +} + +/// Returns a short human-readable label for an expression kind. +/// +/// Used to annotate exec graph snapshot nodes for readability. +/// Includes sub-discriminant info for `BinOp`, `UnOp`, `AssignOp`, and `Lit`. 
+#[must_use] +pub fn expr_kind_short(package: &Package, expr_id: ExprId) -> String { + let expr = package.get_expr(expr_id); + match &expr.kind { + ExprKind::Array(items) => format!("Array(len={})", items.len()), + ExprKind::ArrayLit(items) => format!("ArrayLit(len={})", items.len()), + ExprKind::ArrayRepeat(_, _) => "ArrayRepeat".to_string(), + ExprKind::Assign(_, _) => "Assign".to_string(), + ExprKind::AssignField(_, _, _) => "AssignField".to_string(), + ExprKind::AssignIndex(_, _, _) => "AssignIndex".to_string(), + ExprKind::AssignOp(op, _, _) => format!("AssignOp({op:?})"), + ExprKind::BinOp(op, _, _) => format!("BinOp({op:?})"), + ExprKind::Block(_) => "Block".to_string(), + ExprKind::Call(_, _) => "Call".to_string(), + ExprKind::Closure(_, _) => "Closure".to_string(), + ExprKind::Fail(_) => "Fail".to_string(), + ExprKind::Field(_, _) => "Field".to_string(), + ExprKind::Hole => "Hole".to_string(), + ExprKind::If(_, _, _) => "If".to_string(), + ExprKind::Index(_, _) => "Index".to_string(), + ExprKind::Lit(lit) => format!("Lit({lit:?})"), + ExprKind::Range(_, _, _) => "Range".to_string(), + ExprKind::Return(_) => "Return".to_string(), + ExprKind::String(parts) => format!("String(parts={})", parts.len()), + ExprKind::Struct(_, _, _) => "Struct".to_string(), + ExprKind::Tuple(es) => format!("Tuple(len={})", es.len()), + ExprKind::UnOp(op, _) => format!("UnOp({op:?})"), + ExprKind::UpdateField(_, _, _) => "UpdateField".to_string(), + ExprKind::UpdateIndex(_, _, _) => "UpdateIndex".to_string(), + ExprKind::Var(_, _) => "Var".to_string(), + ExprKind::While(_, _) => "While".to_string(), + } +} + +/// Returns a short human-readable label for a statement kind. +/// +/// Used to annotate exec graph snapshot nodes for readability. 
+#[allow(dead_code)] +pub(crate) fn stmt_kind_short(package: &Package, stmt_id: StmtId) -> String { + let stmt = package.get_stmt(stmt_id); + match &stmt.kind { + StmtKind::Expr(_) => "Expr".to_string(), + StmtKind::Item(_) => "Item".to_string(), + StmtKind::Local(_, _, _) => "Local".to_string(), + StmtKind::Semi(_) => "Semi".to_string(), + } +} + +/// Evaluates the entry exec graph of the given FIR store with a fixed +/// simulator seed for determinism. Returns `Ok(value)` on success, or +/// `Err(error_string)` on evaluation failure. +#[cfg(test)] +#[allow(dead_code)] +pub(crate) fn try_eval_fir_entry( + store: &fir::PackageStore, + pkg_id: fir::PackageId, +) -> Result { + use qsc_eval::backend::{SparseSim, TracingBackend}; + use qsc_eval::output::GenericReceiver; + use qsc_fir::fir::ExecGraphConfig; + + let package = store.get(pkg_id); + let entry_graph = package.entry_exec_graph.clone(); + let mut env = qsc_eval::Env::default(); + let mut sim = SparseSim::new(); + let mut out = Vec::::new(); + let mut receiver = GenericReceiver::new(&mut out); + qsc_eval::eval( + pkg_id, + Some(42), + entry_graph, + ExecGraphConfig::NoDebug, + store, + &mut env, + &mut TracingBackend::no_tracer(&mut sim), + &mut receiver, + ) + .map_err(|(err, _frames)| format!("{err:?}")) +} + +/// Compiles Q# source to FIR using a single lowerer (matching the +/// `qsc_eval` test pattern), and evaluates the entry exec graph. +/// +/// The FIR has no transforms applied — this captures the original program +/// semantics. 
+#[cfg(test)] +#[allow(dead_code)] +pub(crate) fn eval_qsharp_original(source: &str) -> Result { + let mut lowerer = qsc_lowerer::Lowerer::new(); + let mut core = frontend_compile::core(); + run_core_passes(&mut core); + let fir_store = fir::PackageStore::new(); + let core_fir = lowerer.lower_package(&core.package, &fir_store); + let mut hir_store = HirPackageStore::new(core); + + let mut std = frontend_compile::std(&hir_store, TargetCapabilityFlags::empty()); + assert!(std.errors.is_empty()); + assert!(run_default_passes(hir_store.core(), &mut std, PackageType::Lib).is_empty()); + let std_fir = lowerer.lower_package(&std.package, &fir_store); + let std_id = hir_store.insert(std); + + let sources = SourceMap::new(vec![("test.qs".into(), source.into())], None); + let mut unit = frontend_compile::compile( + &hir_store, + &[(PackageId::CORE, None), (std_id, None)], + sources, + TargetCapabilityFlags::empty(), + LanguageFeatures::default(), + ); + assert!(unit.errors.is_empty(), "{:?}", unit.errors); + let pass_errors = run_default_passes(hir_store.core(), &mut unit, PackageType::Exe); + assert!(pass_errors.is_empty(), "{pass_errors:?}"); + let unit_fir = lowerer.lower_package(&unit.package, &fir_store); + let user_hir_id = hir_store.insert(unit); + + let mut fir_store = fir::PackageStore::new(); + fir_store.insert(map_hir_package_to_fir(PackageId::CORE), core_fir); + fir_store.insert(map_hir_package_to_fir(std_id), std_fir); + fir_store.insert(map_hir_package_to_fir(user_hir_id), unit_fir); + + try_eval_fir_entry(&fir_store, map_hir_package_to_fir(user_hir_id)) +} + +/// Compiles Q# source, runs the full FIR transform pipeline, and evaluates +/// the entry exec graph. 
+#[cfg(test)] +#[allow(dead_code)] +pub(crate) fn eval_qsharp_transformed(source: &str) -> Result { + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Full); + try_eval_fir_entry(&store, pkg_id) +} + +/// Asserts semantic equivalence of a Q# program before and after the +/// full FIR transform pipeline. +/// +/// 1. Compiles the original Q# source (no transforms) and evaluates it to +/// get the expected return value. +/// 2. Compiles and runs the full FIR pipeline, then evaluates to get the +/// actual return value. +/// 3. Asserts the two results match (both succeed with equal values, or +/// both fail). +#[cfg(test)] +#[allow(dead_code)] +pub(crate) fn check_semantic_equivalence(source: &str) { + let expected = eval_qsharp_original(source); + let actual = eval_qsharp_transformed(source); + + match (&expected, &actual) { + (Ok(exp_val), Ok(act_val)) => { + assert_eq!( + exp_val, act_val, + "semantic equivalence violated: original returned {exp_val}, \ + transformed returned {act_val}" + ); + } + (Err(exp_err), Err(act_err)) => { + assert_eq!( + exp_err, act_err, + "semantic equivalence violated: original failed with {exp_err}, transformed failed with {act_err}" + ); + } + (Ok(exp_val), Err(err)) => { + panic!("original succeeded with {exp_val} but transformed failed: {err}"); + } + (Err(err), Ok(act_val)) => { + panic!("original failed with {err} but transformed succeeded with {act_val}"); + } + } +} + +#[cfg(test)] +mod tests { + use std::any::Any; + + use super::*; + + fn panic_message(panic: Box) -> String { + match panic.downcast::() { + Ok(message) => *message, + Err(panic) => match panic.downcast::<&str>() { + Ok(message) => (*message).to_string(), + Err(_) => "(non-string panic payload)".to_string(), + }, + } + } + + #[test] + fn staged_runner_with_errors_returns_defunctionalization_diagnostics() { + let source = r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { op(q); } + operation Main() : Unit { + use q = Qubit(); + 
mutable op = H; + mutable n = 3; + while n > 0 { + op = X; + n -= 1; + } + ApplyOp(op, q); + } + "#; + + let (_store, _pkg_id, errors) = + compile_and_run_pipeline_to_with_errors(source, PipelineStage::Full); + + assert!( + !errors.is_empty(), + "expected defunctionalization diagnostics to be returned" + ); + let messages = errors + .iter() + .map(ToString::to_string) + .collect::>() + .join("\n"); + assert!( + messages.contains("callable argument could not be resolved statically"), + "unexpected diagnostics: {messages}" + ); + } + + #[test] + fn checked_staged_runner_panics_on_unexpected_defunctionalization_diagnostics() { + let source = r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { op(q); } + operation Main() : Unit { + use q = Qubit(); + mutable op = H; + mutable n = 3; + while n > 0 { + op = X; + n -= 1; + } + ApplyOp(op, q); + } + "#; + + let panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + let _ = compile_and_run_pipeline_to(source, PipelineStage::Full); + })) + .expect_err("checked staged runner should panic on unexpected diagnostics"); + let message = panic_message(panic); + assert!( + message.contains("compile_and_run_pipeline_to produced FIR transform pipeline errors"), + "unexpected panic: {message}" + ); + assert!( + message.contains("callable argument could not be resolved statically"), + "unexpected panic: {message}" + ); + } + + #[test] + fn reachable_callable_details_report_body_shape() { + let source = r#" + namespace Test { + function Helper(x : Int) : Int { x + 1 } + + @EntryPoint() + function Main() : Int { + Helper(2) + } + } + "#; + + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Mono); + let summary = extract_reachable_callable_details(&store, pkg_id); + + assert!( + summary.contains("callable Helper: input_ty=Int, output_ty=Int"), + "unexpected summary: {summary}" + ); + assert!( + summary.contains("callable Main: input_ty=Unit, output_ty=Int"), + "unexpected summary: 
{summary}" + ); + assert!( + summary.contains("body: block_ty=Int"), + "unexpected summary: {summary}" + ); + + assert_callable_body_terminal_expr_matches_block_type(&store, pkg_id, "Helper"); + assert_callable_body_terminal_expr_matches_block_type(&store, pkg_id, "Main"); + } +} diff --git a/source/compiler/qsc_fir_transforms/src/tuple_compare_lower.rs b/source/compiler/qsc_fir_transforms/src/tuple_compare_lower.rs new file mode 100644 index 0000000000..607c995cea --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/tuple_compare_lower.rs @@ -0,0 +1,218 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Tuple comparison lowering pass. +//! +//! Rewrites `BinOp(Eq/Neq)` on non-empty tuple-typed operands into +//! element-wise scalar comparisons joined by `AndL`/`OrL`. +//! +//! Establishes [`crate::invariants::InvariantLevel::PostTupleCompLower`]: +//! no `BinOp(Eq/Neq)` remains on tuple-typed operands in reachable code. +//! +//! # Pipeline position +//! +//! Runs after UDT erasure (which converts structs to tuples) and before +//! SROA (which decomposes tuple-typed locals into scalars). This ordering +//! is critical: SROA cannot decompose bindings that have whole-value uses +//! such as tuple equality, so this pass eliminates those uses first. +//! +//! # Input patterns +//! +//! - `BinOp(Eq | Neq, lhs, rhs)` where both operands are non-empty +//! `Ty::Tuple`. +//! +//! # Rewrites +//! +//! ```text +//! // Before +//! BinOp(Eq, (a, b, c), (x, y, z)) +//! +//! // After +//! AndL(AndL(Eq(a, x), Eq(b, y)), Eq(c, z)) +//! ``` +//! +//! Nested tuple operands recurse through `lower_single_cmp` so element +//! comparisons are themselves lowered before being folded. +//! +//! # Notes +//! +//! - Synthesized expressions use `EMPTY_EXEC_RANGE` (zero-length exec +//! graph range). The [`crate::exec_graph_rebuild`] pass runs afterward +//! and rebuilds correct exec graphs for the entire package, including +//! 
the synthesized `AndL`/`OrL` nodes **and** any synthesized +//! `Field(..)` accesses produced by `extract_or_field`. + +#[cfg(test)] +mod tests; + +#[cfg(all(test, feature = "slow-proptest-tests"))] +mod semantic_equivalence_tests; + +use crate::fir_builder::{alloc_bin_op_expr, alloc_field_expr, reachable_local_callables}; +use crate::reachability::collect_reachable_from_entry; +use crate::walk_utils::collect_expr_ids_in_entry_and_local_callables; +use qsc_fir::assigner::Assigner; +use qsc_fir::fir::{BinOp, ExprId, ExprKind, Package, PackageId, PackageLookup, PackageStore}; +use qsc_fir::ty::{Prim, Ty}; + +/// Rewrites `BinOp(Eq/Neq)` on non-empty tuple-typed operands into +/// element-wise comparisons in the entry-reachable portion of a package. +/// +/// Scope and idempotence: +/// +/// - Scans only callables whose item reference lives in the target +/// package; cross-package items stay untouched. +/// - Returns early without modification when the target package has no +/// entry expression, since nothing is reachable to rewrite. +/// - Rewrites each matched expression **in place**, preserving its +/// original `ExprId` so downstream references (including +/// execution-graph re-linking) stay stable. +pub fn lower_tuple_comparisons( + store: &mut PackageStore, + package_id: PackageId, + assigner: &mut Assigner, +) { + let package = store.get(package_id); + if package.entry.is_none() { + return; + } + + let reachable = collect_reachable_from_entry(store, package_id); + let package = store.get(package_id); + + // Collect reachable local callable item IDs. + let local_item_ids: Vec<_> = reachable_local_callables(package, package_id, &reachable) + .map(|(item_id, _)| item_id) + .collect(); + + // Collect all ExprIds in entry expression + reachable callable bodies. 
+ let expr_ids = collect_expr_ids_in_entry_and_local_callables(package, &local_item_ids); + + let package = store.get_mut(package_id); + for expr_id in expr_ids { + lower_single_cmp(package, assigner, expr_id); + } +} + +/// Rewrites a single `BinOp(Eq/Neq)` expression with tuple-typed operands +/// into element-wise comparisons. +/// +/// # Before +/// ```text +/// BinOp(Eq, lhs: (A, B), rhs: (A, B)) +/// ``` +/// # After +/// ```text +/// BinOp(AndL, BinOp(Eq, lhs.0, rhs.0), BinOp(Eq, lhs.1, rhs.1)) +/// ``` +/// +/// # Mutations +/// - Rewrites `expr_id`'s `ExprKind` in place. +/// - Allocates field-access and comparison `Expr` nodes through `assigner`. +fn lower_single_cmp(package: &mut Package, assigner: &mut Assigner, expr_id: ExprId) { + let expr = package.get_expr(expr_id); + let (op, lhs_id, rhs_id) = match &expr.kind { + ExprKind::BinOp(op @ (BinOp::Eq | BinOp::Neq), lhs, rhs) => (*op, *lhs, *rhs), + _ => return, + }; + let span = expr.span; + + let lhs_ty = package.get_expr(lhs_id).ty.clone(); + let elem_tys = match &lhs_ty { + Ty::Tuple(elems) if !elems.is_empty() => elems.clone(), + _ => return, + }; + + let joiner = match op { + BinOp::Eq => BinOp::AndL, + BinOp::Neq => BinOp::OrL, + // Guarded by the outer `matches!(op, BinOp::Eq | BinOp::Neq)` + // discriminant above; any other operator exits at the `match + // &expr.kind` early-return. + _ => unreachable!(), + }; + + // Extract element ExprIds: use existing Tuple element IDs when available, + // otherwise synthesize Field accesses. This avoids creating Field + // expressions with empty exec graph ranges on static tuple literals, + // which would cause issues in the partial evaluator's static-classical + // entry-eval path + let lhs_elems = extract_or_field(package, assigner, lhs_id, &elem_tys, span); + let rhs_elems = extract_or_field(package, assigner, rhs_id, &elem_tys, span); + + // Build element-wise comparisons. 
+ let mut cmp_ids: Vec = Vec::with_capacity(elem_tys.len()); + for i in 0..elem_tys.len() { + let elem_cmp = { + let lhs = lhs_elems[i]; + let rhs = rhs_elems[i]; + let ty = Ty::Prim(Prim::Bool); + alloc_bin_op_expr(package, assigner, op, lhs, rhs, ty, span) + }; + // Recursively lower nested tuple comparisons. + lower_single_cmp(package, assigner, elem_cmp); + cmp_ids.push(elem_cmp); + } + + // Fold element comparisons left-to-right with the joiner. + let result_id = fold_left(package, assigner, &cmp_ids, joiner, span); + + // Rewrite the original expression in-place. + let result_expr = package.get_expr(result_id); + let result_kind = result_expr.kind.clone(); + let target = package.exprs.get_mut(expr_id).expect("expr exists"); + target.kind = result_kind; + target.ty = Ty::Prim(Prim::Bool); +} + +/// Extracts element `ExprId`s from a tuple-typed expression. +/// +/// If the expression is `ExprKind::Tuple(es)`, returns the element IDs +/// directly. Otherwise, synthesizes `Field(expr, Path([i]))` for each +/// element. +fn extract_or_field( + package: &mut Package, + assigner: &mut Assigner, + tuple_expr_id: ExprId, + elem_tys: &[Ty], + span: qsc_data_structures::span::Span, +) -> Vec { + let expr = package.get_expr(tuple_expr_id); + if let ExprKind::Tuple(es) = &expr.kind { + assert_eq!( + es.len(), + elem_tys.len(), + "tuple expression arity must match type arity" + ); + return es.clone(); + } + elem_tys + .iter() + .enumerate() + .map(|(i, ty)| { + let elem_ty = ty.clone(); + alloc_field_expr(package, assigner, tuple_expr_id, i, elem_ty, span) + }) + .collect() +} + +/// Folds expressions left-to-right with a joiner operator. +/// +/// `[a, b, c]` with `AndL` becomes `AndL(AndL(a, b), c)`. 
+fn fold_left( + package: &mut Package, + assigner: &mut Assigner, + exprs: &[ExprId], + joiner: BinOp, + span: qsc_data_structures::span::Span, +) -> ExprId { + assert!(!exprs.is_empty(), "fold_left requires at least one expr"); + let mut acc = exprs[0]; + for &e in &exprs[1..] { + acc = { + let ty = Ty::Prim(Prim::Bool); + alloc_bin_op_expr(package, assigner, joiner, acc, e, ty, span) + }; + } + acc +} diff --git a/source/compiler/qsc_fir_transforms/src/tuple_compare_lower/semantic_equivalence_tests.rs b/source/compiler/qsc_fir_transforms/src/tuple_compare_lower/semantic_equivalence_tests.rs new file mode 100644 index 0000000000..b1e3144e8a --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/tuple_compare_lower/semantic_equivalence_tests.rs @@ -0,0 +1,155 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use indoc::formatdoc; +use indoc::indoc; +use proptest::prelude::*; + +#[test] +fn tuple_eq_comparison_preserves_semantics() { + crate::test_utils::check_semantic_equivalence(indoc! {r#" + namespace Test { + @EntryPoint() + function Main() : Bool { + let a = (1, 2); + let b = (1, 2); + a == b + } + } + "#}); +} + +#[test] +fn tuple_neq_comparison_preserves_semantics() { + crate::test_utils::check_semantic_equivalence(indoc! {r#" + namespace Test { + @EntryPoint() + function Main() : Bool { + let a = (1, 2); + let b = (3, 4); + a != b + } + } + "#}); +} + +#[test] +fn nested_tuple_eq_preserves_semantics() { + crate::test_utils::check_semantic_equivalence(indoc! 
{r#" + namespace Test { + @EntryPoint() + function Main() : Bool { + let a = ((1, 2), 3); + let b = ((1, 2), 3); + a == b + } + } + "#}); +} + +fn flat_int_tuple_comparison_pattern() -> impl Strategy { + ( + 2usize..=4, + prop::bool::ANY, + prop::collection::vec(-20i64..=20, 4), + prop::collection::vec(-20i64..=20, 4), + ) + .prop_map(|(width, use_not_equal, left_values, right_values)| { + let left_tuple = left_values + .into_iter() + .take(width) + .map(|value| value.to_string()) + .collect::>() + .join(", "); + let right_tuple = right_values + .into_iter() + .take(width) + .map(|value| value.to_string()) + .collect::>() + .join(", "); + let operator = if use_not_equal { "!=" } else { "==" }; + + formatdoc! {r#" + namespace Test {{ + @EntryPoint() + function Main() : Bool {{ + let left = ({left_tuple}); + let right = ({right_tuple}); + left {operator} right + }} + }} + "#} + }) +} + +proptest! { + #![proptest_config(ProptestConfig::with_cases(50))] + + #[test] + fn flat_int_tuple_comparison_preserves_semantics(source in flat_int_tuple_comparison_pattern()) { + crate::test_utils::check_semantic_equivalence(&source); + } +} + +fn qsharp_bool(value: bool) -> &'static str { + if value { "true" } else { "false" } +} + +fn nested_mixed_tuple_comparison_strategy() -> impl Strategy { + ( + prop::bool::ANY, + -16i64..=16, + prop::bool::ANY, + -16i64..=16, + prop::bool::ANY, + -16i64..=16, + -16i64..=16, + prop::bool::ANY, + -16i64..=16, + prop::bool::ANY, + -16i64..=16, + ) + .prop_map( + |( + use_not_equal, + left_a, + left_flag_a, + left_double, + left_flag_b, + left_c, + right_a, + right_flag_a, + right_double, + right_flag_b, + right_c, + )| { + let operator = if use_not_equal { "!=" } else { "==" }; + let left_flag_a = qsharp_bool(left_flag_a); + let left_flag_b = qsharp_bool(left_flag_b); + let right_flag_a = qsharp_bool(right_flag_a); + let right_flag_b = qsharp_bool(right_flag_b); + + formatdoc! 
{r#" + namespace Test {{ + @EntryPoint() + function Main() : Bool {{ + let left = (({left_a}, {left_flag_a}), ({left_double}.0, ({left_flag_b}, {left_c}))); + let right = (({right_a}, {right_flag_a}), ({right_double}.0, ({right_flag_b}, {right_c}))); + left {operator} right + }} + }} + "#} + }, + ) +} + +proptest! { + #![proptest_config(ProptestConfig::with_cases(32))] + + #[test] + fn nested_mixed_tuple_comparison_preserves_semantics( + source in nested_mixed_tuple_comparison_strategy() + ) { + crate::test_utils::check_semantic_equivalence(&source); + } +} diff --git a/source/compiler/qsc_fir_transforms/src/tuple_compare_lower/tests.rs b/source/compiler/qsc_fir_transforms/src/tuple_compare_lower/tests.rs new file mode 100644 index 0000000000..9cde01f6e2 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/tuple_compare_lower/tests.rs @@ -0,0 +1,436 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use crate::test_utils::{PipelineStage, compile_and_run_pipeline_to}; +use expect_test::{Expect, expect}; +use indoc::indoc; +use qsc_fir::assigner::Assigner; +use qsc_fir::fir::{BinOp, CallableImpl, ExprKind, ItemKind, PackageLookup, StmtKind}; + +/// Runs the pipeline through tuple comparison lowering and extracts a summary +/// of the expression tree for the entry callable's body statements. +fn check(source: &str, expect: &Expect) { + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::TupleCompLower); + let result = extract_expr_summary(&store, pkg_id); + expect.assert_eq(&result); +} + +fn check_callable_expr_summary(source: &str, expect: &Expect) { + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::TupleCompLower); + let result = extract_callable_expr_summary(&store, pkg_id); + expect.assert_eq(&result); +} + +/// Extracts a summary of expression kinds in the entry callable's body, +/// focusing on `BinOp` expressions to verify lowering. 
/// Walks every entry-reachable local callable and formats the expression
/// tree of each body statement into one sorted, newline-joined summary.
/// The stripped generic (`Vec` with no argument) is restored to
/// `Vec<String>`; bare `Vec` in annotation position is E0107.
fn extract_expr_summary(
    store: &qsc_fir::fir::PackageStore,
    pkg_id: qsc_fir::fir::PackageId,
) -> String {
    let package = store.get(pkg_id);
    let reachable = crate::reachability::collect_reachable_from_entry(store, pkg_id);
    let mut lines: Vec<String> = Vec::new();

    for store_id in &reachable {
        // Only summarize items that live in the target package.
        if store_id.package != pkg_id {
            continue;
        }
        let item = package.get_item(store_id.item);
        if let ItemKind::Callable(decl) = &item.kind
            && let CallableImpl::Spec(spec) = &decl.implementation
        {
            let block = package.get_block(spec.body.block);
            for &stmt_id in &block.stmts {
                let stmt = package.get_stmt(stmt_id);
                match &stmt.kind {
                    StmtKind::Expr(e) | StmtKind::Semi(e) => {
                        lines.push(format_expr(package, *e, 0));
                    }
                    StmtKind::Local(_, _, e) => {
                        lines.push(format!("local init: {}", format_expr(package, *e, 0)));
                    }
                    StmtKind::Item(_) => {}
                }
            }
        }
    }

    // Sorting keeps the summary stable regardless of reachability order.
    lines.sort();
    lines.join("\n")
}

/// Like `extract_expr_summary`, but groups statements under a
/// `callable <name>:` header per reachable callable.
fn extract_callable_expr_summary(
    store: &qsc_fir::fir::PackageStore,
    pkg_id: qsc_fir::fir::PackageId,
) -> String {
    let package = store.get(pkg_id);
    let reachable = crate::reachability::collect_reachable_from_entry(store, pkg_id);
    let mut callables = Vec::new();

    for store_id in &reachable {
        if store_id.package != pkg_id {
            continue;
        }
        let item = package.get_item(store_id.item);
        if let ItemKind::Callable(decl) = &item.kind
            && let CallableImpl::Spec(spec) = &decl.implementation
        {
            let block = package.get_block(spec.body.block);
            let mut lines = vec![format!("callable {}:", decl.name.name)];
            for &stmt_id in &block.stmts {
                let stmt = package.get_stmt(stmt_id);
                match &stmt.kind {
                    StmtKind::Expr(e) => {
                        // Two-space label indent matches the expect! outputs
                        // (collapsed by the whitespace mangling) — TODO confirm.
                        lines.push("  expr:".to_string());
                        lines.push(format_expr(package, *e, 2));
                    }
                    StmtKind::Semi(e) => {
                        lines.push("  semi:".to_string());
                        lines.push(format_expr(package, *e, 2));
                    }
                    StmtKind::Local(_, _, e) => {
                        lines.push("  local init:".to_string());
                        lines.push(format_expr(package, *e, 2));
                    }
StmtKind::Item(_) => {} + } + } + callables.push(lines.join("\n")); + } + } + + callables.sort(); + callables.join("\n") +} + +/// Formats an expression recursively, showing `BinOp` structure. +fn format_expr( + package: &qsc_fir::fir::Package, + expr_id: qsc_fir::fir::ExprId, + depth: usize, +) -> String { + let expr = package.get_expr(expr_id); + let indent = " ".repeat(depth); + match &expr.kind { + ExprKind::BinOp(op, lhs, rhs) => { + let op_str = match op { + BinOp::Eq => "Eq", + BinOp::Neq => "Neq", + BinOp::AndL => "AndL", + BinOp::OrL => "OrL", + _ => "Other", + }; + format!( + "{indent}BinOp({op_str}, ty={}):\n{}\n{}", + expr.ty, + format_expr(package, *lhs, depth + 1), + format_expr(package, *rhs, depth + 1), + ) + } + ExprKind::Field(target, field) => { + format!("{indent}Field({}, {field}, ty={})", target, expr.ty) + } + ExprKind::Tuple(es) => { + let elems: Vec = es.iter().map(|e| format!("{e}")).collect(); + format!("{indent}Tuple([{}], ty={})", elems.join(", "), expr.ty) + } + ExprKind::Var(res, _) => { + format!("{indent}Var({res}, ty={})", expr.ty) + } + ExprKind::Lit(lit) => { + format!("{indent}Lit({lit:?}, ty={})", expr.ty) + } + ExprKind::Call(callee, args) => { + format!("{indent}Call({callee}, {args}, ty={})", expr.ty) + } + _ => { + format!("{indent}Expr({expr_id}, ty={})", expr.ty) + } + } +} + +/// Verifies the full pipeline succeeds (including QIR generation) for dynamic +/// tuple comparisons. 
+fn generate_qir(source: &str) -> String { + use qsc_codegen::qir::fir_to_qir; + use qsc_data_structures::target::TargetCapabilityFlags; + use qsc_partial_eval::ProgramEntry; + + let capabilities = TargetCapabilityFlags::Adaptive | TargetCapabilityFlags::IntegerComputations; + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Full); + let package = store.get(pkg_id); + let entry = ProgramEntry { + exec_graph: package.entry_exec_graph.clone(), + expr: ( + pkg_id, + package + .entry + .expect("package must have an entry expression"), + ) + .into(), + }; + let compute_properties = qsc_rca::Analyzer::init(&store, capabilities).analyze_all(); + fir_to_qir(&store, capabilities, &compute_properties, &entry).expect("QIR generation failed") +} + +#[test] +fn dynamic_tuple_eq_decomposed() { + // Tuple comparison with Result values decomposes into element-wise AndL. + check( + "operation Main() : Bool { + use (q0, q1) = (Qubit(), Qubit()); + let (r0, r1) = (M(q0), M(q1)); + (r0, r1) == (Zero, Zero) + }", + &expect![[r#" + Call(27, 28, ty=Unit) + Call(30, 31, ty=Unit) + Var(Local 7, ty=Bool) + local init: BinOp(AndL, ty=Bool): + BinOp(Eq, ty=Bool): + Var(Local 5, ty=Result) + Lit(Result(Zero), ty=Result) + BinOp(Eq, ty=Bool): + Var(Local 6, ty=Result) + Lit(Result(Zero), ty=Result) + local init: Call(4, 5, ty=Qubit) + local init: Call(7, 8, ty=Qubit) + local init: Tuple([10, 11], ty=(Qubit, Qubit)) + local init: Tuple([13, 16], ty=(Result, Result))"#]], + ); +} + +#[test] +fn dynamic_tuple_neq_decomposed() { + // Tuple inequality with Result values decomposes into element-wise OrL. 
+ check( + "operation Main() : Bool { + use (q0, q1) = (Qubit(), Qubit()); + let (r0, r1) = (M(q0), M(q1)); + (r0, r1) != (Zero, Zero) + }", + &expect![[r#" + Call(27, 28, ty=Unit) + Call(30, 31, ty=Unit) + Var(Local 7, ty=Bool) + local init: BinOp(OrL, ty=Bool): + BinOp(Neq, ty=Bool): + Var(Local 5, ty=Result) + Lit(Result(Zero), ty=Result) + BinOp(Neq, ty=Bool): + Var(Local 6, ty=Result) + Lit(Result(Zero), ty=Result) + local init: Call(4, 5, ty=Qubit) + local init: Call(7, 8, ty=Qubit) + local init: Tuple([10, 11], ty=(Qubit, Qubit)) + local init: Tuple([13, 16], ty=(Result, Result))"#]], + ); +} + +#[test] +fn classical_tuple_eq_decomposed() { + // Purely classical tuple comparison IS now decomposed into element-wise AndL. + check( + "function Main() : Bool { + (1, 2) == (3, 4) + }", + &expect![[r#" + BinOp(AndL, ty=Bool): + BinOp(Eq, ty=Bool): + Lit(Int(1), ty=Int) + Lit(Int(3), ty=Int) + BinOp(Eq, ty=Bool): + Lit(Int(2), ty=Int) + Lit(Int(4), ty=Int)"#]], + ); +} + +#[test] +fn mixed_classical_dynamic_tuple_decomposed() { + // Tuple containing both classical and dynamic types IS decomposed + // because it contains Result. + check( + "operation Main() : Bool { + use q = Qubit(); + let r = M(q); + (1, r) == (0, Zero) + }", + &expect![[r#" + Call(17, 18, ty=Unit) + Var(Local 3, ty=Bool) + local init: BinOp(AndL, ty=Bool): + BinOp(Eq, ty=Bool): + Lit(Int(1), ty=Int) + Lit(Int(0), ty=Int) + BinOp(Eq, ty=Bool): + Var(Local 2, ty=Result) + Lit(Result(Zero), ty=Result) + local init: Call(4, 5, ty=Qubit) + local init: Call(7, 8, ty=Result)"#]], + ); +} + +#[test] +fn dynamic_tuple_eq_qir_succeeds() { + // Verify the full pipeline and QIR generation succeeds for tuple + // comparison with Result values. + let qir = generate_qir( + "operation Main() : Bool { + use (q0, q1) = (Qubit(), Qubit()); + let (r0, r1) = (M(q0), M(q1)); + (r0, r1) == (Zero, Zero) + }", + ); + // QIR should be non-empty, meaning the pipeline succeeded. 
+ assert!(!qir.is_empty(), "QIR generation should succeed"); +} + +#[test] +fn nested_tuple_eq_recursively_decomposes_inner_elements() { + check( + indoc! {" + operation Main() : Bool { + use q1 = Qubit(); + use q2 = Qubit(); + let a = (M(q1), M(q2)); + let b = (M(q1), M(q2)); + (a, a) == (b, b) + } + "}, + &expect![[r#" + Call(31, 32, ty=Unit) + Call(34, 35, ty=Unit) + Var(Local 5, ty=Bool) + local init: BinOp(AndL, ty=Bool): + BinOp(AndL, ty=Bool): + BinOp(Eq, ty=Bool): + Field(25, Path([0]), ty=Result) + Field(28, Path([0]), ty=Result) + BinOp(Eq, ty=Bool): + Field(25, Path([1]), ty=Result) + Field(28, Path([1]), ty=Result) + BinOp(AndL, ty=Bool): + BinOp(Eq, ty=Bool): + Field(26, Path([0]), ty=Result) + Field(29, Path([0]), ty=Result) + BinOp(Eq, ty=Bool): + Field(26, Path([1]), ty=Result) + Field(29, Path([1]), ty=Result) + local init: Call(4, 5, ty=Qubit) + local init: Call(7, 8, ty=Qubit) + local init: Tuple([10, 13], ty=(Result, Result)) + local init: Tuple([17, 20], ty=(Result, Result))"#]], + ); +} + +#[test] +fn nested_tuple_neq_recursively_decomposes_inner_elements() { + check( + indoc! {" + function Main() : Bool { + ((1, 2), (3, 4)) != ((1, 5), (3, 4)) + } + "}, + &expect![[r#"BinOp(OrL, ty=Bool): + BinOp(OrL, ty=Bool): + BinOp(Neq, ty=Bool): + Lit(Int(1), ty=Int) + Lit(Int(1), ty=Int) + BinOp(Neq, ty=Bool): + Lit(Int(2), ty=Int) + Lit(Int(5), ty=Int) + BinOp(OrL, ty=Bool): + BinOp(Neq, ty=Bool): + Lit(Int(3), ty=Int) + Lit(Int(3), ty=Int) + BinOp(Neq, ty=Bool): + Lit(Int(4), ty=Int) + Lit(Int(4), ty=Int)"#]], + ); +} + +#[test] +fn helper_callable_tuple_neq_is_lowered() { + check_callable_expr_summary( + indoc! 
{" + function Helper() : Bool { + (0, 0) != (0, 1) + } + + function Main() : Bool { + Helper() + } + "}, + &expect![[r#"callable Helper: + expr: + BinOp(OrL, ty=Bool): + BinOp(Neq, ty=Bool): + Lit(Int(0), ty=Int) + Lit(Int(0), ty=Int) + BinOp(Neq, ty=Bool): + Lit(Int(0), ty=Int) + Lit(Int(1), ty=Int) +callable Main: + expr: + Call(11, 12, ty=Bool)"#]], + ); +} + +#[test] +fn empty_tuple_eq_unchanged_no_decomposition() { + check( + indoc! {" + function Main() : Bool { + () == () + } + "}, + &expect![[r#" + BinOp(Eq, ty=Bool): + Tuple([], ty=Unit) + Tuple([], ty=Unit)"#]], + ); +} + +#[test] +fn tuple_compare_lower_is_idempotent() { + let source = indoc! {" + namespace Test { + @EntryPoint() + operation Main() : Bool { + use (q0, q1) = (Qubit(), Qubit()); + let pair = (M(q0), M(q1)); + pair == pair + } + } + "}; + let (mut store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::TupleCompLower); + let first = crate::pretty::write_package_qsharp(&store, pkg_id); + let mut assigner = Assigner::from_package(store.get(pkg_id)); + crate::tuple_compare_lower::lower_tuple_comparisons(&mut store, pkg_id, &mut assigner); + let second = crate::pretty::write_package_qsharp(&store, pkg_id); + assert_eq!(first, second, "tuple_compare_lower should be idempotent"); +} + +#[test] +fn entry_expression_tuple_comparison_is_lowered() { + // Tuple comparison in an @EntryPoint callable is lowered correctly. + // Documents that the entry expression path is covered by tuple_compare_lower. + check( + indoc! 
{" + namespace Test { + @EntryPoint() + operation Main() : Bool { + (1, 2) == (1, 2) + } + } + "}, + &expect![[r#" + BinOp(AndL, ty=Bool): + BinOp(Eq, ty=Bool): + Lit(Int(1), ty=Int) + Lit(Int(1), ty=Int) + BinOp(Eq, ty=Bool): + Lit(Int(2), ty=Int) + Lit(Int(2), ty=Int)"#]], + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/udt_erase.rs b/source/compiler/qsc_fir_transforms/src/udt_erase.rs new file mode 100644 index 0000000000..abb11031e3 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/udt_erase.rs @@ -0,0 +1,800 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! UDT erasure pass. +//! +//! Replaces every `Ty::Udt` in the entry-reachable package closure with its +//! pure tuple or scalar type (via `get_pure_ty()`) and converts +//! `ExprKind::Struct` construction expressions into tuple or scalar +//! expressions. Also eliminates UDT constructor calls (`ExprKind::Call` +//! whose callee is an `ItemKind::Ty` item) and lowers +//! `ExprKind::UpdateField` and `ExprKind::AssignField` with `Field::Path` +//! into explicit tuple constructions with field extractions. Additionally, +//! lowers `ExprKind::Field` read access expressions on scalar-erased +//! single-field newtypes. After this pass, no `Ty::Udt`, `ExprKind::Struct`, +//! UDT constructor call, UDT-targeted `UpdateField`/`AssignField`, or +//! `Field::Path` on non-tuple types remains in the target package or in any +//! package that contains an entry-reachable callable. +//! +//! Establishes [`crate::invariants::InvariantLevel::PostUdtErase`]. +//! +//! This must run before partial evaluation and backend code generation, which +//! may inspect reachable cross-package FIR but do not support UDT types or +//! `ExprKind::Struct` in the code they consume. +//! +//! UDT erasure is a standard type-erasure technique common in ML-family +//! compilers and functional languages targeting lower-level IRs. +//! +//! # Input patterns +//! +//! 
- `ExprKind::Struct(Udt, copy_opt, fields)` — UDT construction (with or +//! without a copy-update source). +//! - `ExprKind::UpdateField(record, Field::Path, replace)` / `AssignField` +//! — field-path-based record updates. +//! - Any expression, pattern, block, or callable signature carrying a +//! `Ty::Udt`. +//! +//! # Rewrites +//! +//! Construction of `newtype Pair = (Int, Int); new Pair { First = 1, Second = 2 }`: +//! +//! ```text +//! // Before +//! Struct(Pair, None, [First = 1, Second = 2]) +//! +//! // After +//! Tuple([1, 2]) +//! ``` +//! +//! Copy-update `new Pair { ...src, First = 9 }`: +//! +//! ```text +//! // Before +//! Struct(Pair, Some(src), [First = 9]) +//! +//! // After +//! Tuple([9, Field(src, Path([1]))]) +//! ``` +//! +//! Update-field `record w/ ::First <- 9`: +//! +//! ```text +//! // Before +//! UpdateField(record, Path([0]), 9) +//! +//! // After +//! Tuple([9, Field(record, Path([1]))]) +//! ``` +//! +//! # Notes +//! +//! - Scope: the target package and every package reachable from its entry +//! expression are mutated in place. Cross-package UDT resolution still +//! uses the whole store via the UDT cache. +//! - Synthesized expressions use `EMPTY_EXEC_RANGE`; +//! [`crate::exec_graph_rebuild`] rebuilds correct exec graphs at the end +//! of the pipeline. + +#[cfg(test)] +mod tests; + +#[cfg(all(test, feature = "slow-proptest-tests"))] +mod semantic_equivalence_tests; + +use crate::cloner::FirCloner; +use crate::{EMPTY_EXEC_RANGE, reachability::collect_reachable_package_closure_from_entry}; +use qsc_data_structures::span::Span; +use qsc_fir::assigner::Assigner; +use qsc_fir::fir::{ + BlockId, Expr, ExprId, ExprKind, Field, FieldAssign, FieldPath, ItemKind, LocalItemId, Package, + PackageId, PackageStore, PatId, Res, +}; +use qsc_fir::ty::{Arrow, Ty}; + +use rustc_hash::FxHashMap; + +/// Maps `(PackageId, LocalItemId)` → pure `Ty` for every UDT definition +/// in the store. 
+type UdtCache = FxHashMap<(PackageId, LocalItemId), Ty>; + +/// Erases all `Ty::Udt` types and `ExprKind::Struct` expressions in the +/// target package's reachable package closure, while resolving UDT +/// definitions from the whole store. +/// +/// Returns immediately without modification if the target package has no +/// entry expression (nothing is reachable to rewrite). +pub fn erase_udts(store: &mut PackageStore, package_id: PackageId, assigner: &mut Assigner) { + let package = store.get(package_id); + if package.entry.is_none() { + return; + } + + // Build a resolution cache from all UDT items across all packages. + let udt_cache = build_udt_cache(store); + + // Erase UDTs in the target package and in any package that contains an + // entry-reachable callable. UDT definition lookup still spans the whole + // store so cross-package references resolve correctly. + let pkg_ids: Vec = collect_reachable_package_closure_from_entry(store, package_id) + .into_iter() + .collect(); + for pkg_id in pkg_ids { + if pkg_id == package_id { + // Use the threaded assigner for the target package. + let owned = std::mem::take(assigner); + let mut cloner = FirCloner::from_assigner(owned); + erase_udts_in_package(store.get_mut(pkg_id), &udt_cache, &mut cloner); + *assigner = cloner.into_assigner(); + } else { + let mut cloner = FirCloner::new(store.get(pkg_id)); + erase_udts_in_package(store.get_mut(pkg_id), &udt_cache, &mut cloner); + } + } +} + +/// Erases UDT types and struct expressions in a single package, rewriting +/// every expression type, pattern type, block type, callable signature, +/// and struct construction in place. Called once per package in the +/// entry-reachable closure. 
+/// +/// # Before +/// ```text +/// Expr { ty: Udt(MyStruct), kind: Struct(res, None, fields) } +/// Pat { ty: Udt(MyStruct) } +/// Block { ty: Udt(MyStruct) } +/// ``` +/// # After +/// ```text +/// Expr { ty: Tuple([Int, Bool]), kind: Tuple([v0, v1]) } +/// Pat { ty: Tuple([Int, Bool]) } +/// Block { ty: Tuple([Int, Bool]) } +/// ``` +/// +/// # Mutations +/// - Rewrites `Expr.ty`, `Expr.kind`, `Pat.ty`, `Block.ty`, and callable +/// output types in place. +/// - Allocates field-extraction `Expr` nodes through `cloner` for +/// copy-update and field-update lowering. +fn erase_udts_in_package(package: &mut Package, udt_cache: &UdtCache, cloner: &mut FirCloner) { + // Rewrite all expression types and Struct expressions. + let expr_ids: Vec = package.exprs.iter().map(|(id, _)| id).collect(); + for expr_id in expr_ids { + // Rewrite the expression's type. + let expr = package.exprs.get(expr_id).expect("expr should exist"); + let new_ty = resolve_ty(udt_cache, &expr.ty); + let kind = expr.kind.clone(); + let expr_span = expr.span; + + let expr_mut = package.exprs.get_mut(expr_id).expect("expr should exist"); + expr_mut.ty = new_ty; + + // Convert Struct expressions to Tuple expressions. + if let ExprKind::Struct(_res, copy, fields) = &kind { + if let Some(copy_id) = copy { + lower_copy_update_struct( + package, cloner, udt_cache, expr_id, *copy_id, fields, expr_span, + ); + } else { + let mut indexed: Vec<(usize, ExprId)> = fields + .iter() + .filter_map(|fa| { + if let Field::Path(FieldPath { indices }) = &fa.field { + indices.first().map(|&idx| (idx, fa.value)) + } else { + None + } + }) + .collect(); + indexed.sort_by_key(|(idx, _)| *idx); + let values: Vec = indexed.into_iter().map(|(_, v)| v).collect(); + + if values.len() == 1 { + // The expression type has already been resolved to the + // UDT's pure type. For struct-syntax UDTs the pure type + // is Tuple([T]), while for `newtype X = T` it is scalar T. 
+ let is_tuple_ty = matches!( + &package.exprs.get(expr_id).expect("expr should exist").ty, + Ty::Tuple(_) + ); + if is_tuple_ty { + // Struct syntax: pure type is Tuple([T]). Keep as + // tuple to match the pattern type. + let expr_mut = package.exprs.get_mut(expr_id).expect("expr should exist"); + expr_mut.kind = ExprKind::Tuple(values); + } else { + // newtype X = T: pure type is scalar T. Unwrap to + // the inner expression directly. + let inner_expr = package + .exprs + .get(values[0]) + .expect("inner expr should exist"); + let inner_kind = inner_expr.kind.clone(); + let inner_ty = inner_expr.ty.clone(); + let expr_mut = package.exprs.get_mut(expr_id).expect("expr should exist"); + expr_mut.kind = inner_kind; + expr_mut.ty = resolve_ty(udt_cache, &inner_ty); + } + } else { + // Multi-field UDT: replace with a tuple of the field + // values in declaration order. + let expr_mut = package.exprs.get_mut(expr_id).expect("expr should exist"); + expr_mut.kind = ExprKind::Tuple(values); + } + } + } + + // Eliminate UDT constructor calls. + eliminate_udt_constructor_call(package, udt_cache, expr_id, &kind); + + // Lower UpdateField and AssignField with Field::Path into tuple + // constructions. + lower_field_updates(package, cloner, udt_cache, expr_id, &kind, expr_span); + + // Lower Field read expressions on scalar-erased types (Field::Path + // expressions where the record type is not a tuple). + lower_scalar_field_read(package, udt_cache, expr_id, &kind); + } + + // Rewrite all pattern types. + let pat_ids: Vec = package.pats.iter().map(|(id, _)| id).collect(); + for pat_id in pat_ids { + let pat = package.pats.get(pat_id).expect("pat should exist"); + let new_ty = resolve_ty(udt_cache, &pat.ty); + let pat_mut = package.pats.get_mut(pat_id).expect("pat should exist"); + pat_mut.ty = new_ty; + } + + // Rewrite all block types. 
+ let block_ids: Vec = package.blocks.iter().map(|(id, _)| id).collect(); + for block_id in block_ids { + let block = package.blocks.get(block_id).expect("block should exist"); + let new_ty = resolve_ty(udt_cache, &block.ty); + let block_mut = package + .blocks + .get_mut(block_id) + .expect("block should exist"); + block_mut.ty = new_ty; + } + + // Rewrite callable signatures (input pattern types are already handled + // above, but output types are stored separately in CallableDecl). + let item_ids: Vec = package.items.iter().map(|(id, _)| id).collect(); + for item_id in item_ids { + let item = package.items.get(item_id).expect("item should exist"); + if let ItemKind::Callable(decl) = &item.kind { + let new_output = resolve_ty(udt_cache, &decl.output); + if new_output != decl.output { + let item_mut = package.items.get_mut(item_id).expect("item should exist"); + if let ItemKind::Callable(decl_mut) = &mut item_mut.kind { + decl_mut.output = new_output; + } + } + } + } +} + +/// Eliminates a UDT constructor call if `kind` is `ExprKind::Call` whose +/// callee resolves to an `ItemKind::Ty` item. After type resolution the +/// constructor is an identity/wrapping function. +/// +/// # Before +/// ```text +/// Call(Var(Item(UdtConstructor)), arg) // e.g. MyStruct(42) +/// ``` +/// # After +/// ```text +/// arg // or Tuple([arg]) for trailing-comma newtypes +/// ``` +/// +/// # Mutations +/// - Rewrites `expr_id`'s `ExprKind` and `Ty` in place. 
+fn eliminate_udt_constructor_call( + package: &mut Package, + udt_cache: &UdtCache, + expr_id: ExprId, + kind: &ExprKind, +) { + let ExprKind::Call(callee_id, arg_id) = kind else { + return; + }; + let callee = package.exprs.get(*callee_id).expect("callee should exist"); + let ExprKind::Var(Res::Item(item_id), _) = &callee.kind else { + return; + }; + let Some(pure_ty) = udt_cache.get(&(item_id.package, item_id.item)) else { + return; + }; + let resolved_pure = resolve_ty(udt_cache, pure_ty); + let arg = package.exprs.get(*arg_id).expect("arg should exist"); + let arg_ty_resolved = resolve_ty(udt_cache, &arg.ty); + + if arg_ty_resolved != resolved_pure && matches!(&resolved_pure, Ty::Tuple(_)) { + // Trailing-comma single-field: scalar arg doesn't match + // Tuple([T]) pure type — wrap in a tuple. + let expr_mut = package.exprs.get_mut(expr_id).expect("expr should exist"); + expr_mut.kind = ExprKind::Tuple(vec![*arg_id]); + expr_mut.ty = resolved_pure; + } else { + // Argument type matches the erased constructor input (multi-field + // or scalar newtype) — replace the call with the argument. + let arg = package.exprs.get(*arg_id).expect("arg should exist"); + let arg_kind = arg.kind.clone(); + let arg_ty = arg.ty.clone(); + let expr_mut = package.exprs.get_mut(expr_id).expect("expr should exist"); + expr_mut.kind = arg_kind; + expr_mut.ty = resolve_ty(udt_cache, &arg_ty); + } +} + +/// Lowers a copy-update struct expression `new Foo { ...copy, X = val }` +/// into a tuple construction, replacing the expression kind in place. +/// +/// # Before +/// ```text +/// Struct(res, Some(copy_id), [FieldAssign(Path([1]), val)]) +/// ``` +/// # After +/// ```text +/// Tuple([Field(copy, Path([0])), val]) // field 0 extracted, field 1 replaced +/// ``` +/// +/// # Mutations +/// - Rewrites `expr_id`'s `ExprKind` and `Ty` in place. +/// - Allocates field-extraction `Expr` nodes through `cloner`. 
+fn lower_copy_update_struct( + package: &mut Package, + cloner: &mut FirCloner, + udt_cache: &UdtCache, + expr_id: ExprId, + copy_id: ExprId, + fields: &[FieldAssign], + span: Span, +) { + // Check for a whole-value replacement (single-field UDT where the + // field path is empty). + let whole_value_replace = fields.iter().find_map(|fa| { + if let Field::Path(FieldPath { indices }) = &fa.field + && indices.is_empty() + { + return Some(fa.value); + } + None + }); + + if let Some(replacement) = whole_value_replace { + // Single-field UDT (scalar type): the copy-update replaces the + // entire value. + let replace_expr = package + .exprs + .get(replacement) + .expect("replacement should exist"); + let replace_kind = replace_expr.kind.clone(); + let replace_ty = replace_expr.ty.clone(); + let expr_mut = package.exprs.get_mut(expr_id).expect("expr should exist"); + expr_mut.kind = replace_kind; + expr_mut.ty = resolve_ty(udt_cache, &replace_ty); + return; + } + + // Build a map of field index → replacement ExprId. + let updates: FxHashMap = fields + .iter() + .filter_map(|fa| { + if let Field::Path(FieldPath { indices }) = &fa.field { + indices.first().map(|&idx| (idx, fa.value)) + } else { + None + } + }) + .collect(); + + // Resolve the type of the copy source to determine the tuple + // structure (may not yet be resolved due to ID ordering). + let copy_raw_ty = &package + .exprs + .get(copy_id) + .expect("copy source should exist") + .ty; + let copy_ty = resolve_ty(udt_cache, copy_raw_ty); + + if let Ty::Tuple(elems) = ©_ty { + // Multi-field UDT: build a tuple with replacements at updated + // indices and field extractions elsewhere. 
+ let mut field_ids = Vec::with_capacity(elems.len()); + for (j, elem_ty) in elems.iter().enumerate() { + if let Some(&replacement) = updates.get(&j) { + field_ids.push(replacement); + } else { + let field_id = alloc_field_expr(package, cloner, copy_id, j, elem_ty, span); + field_ids.push(field_id); + } + } + let expr_mut = package.exprs.get_mut(expr_id).expect("expr should exist"); + expr_mut.kind = ExprKind::Tuple(field_ids); + } else { + // Single-field UDTs erase to scalars. Depending on how the field + // path was lowered upstream, the update may arrive as an empty path, + // index 0, or a field marker that no longer carries a useful path. + // Any explicit field assignment on a scalar-erased copy-update must + // therefore replace the whole value. + if let Some(&replacement) = updates + .get(&0) + .or_else(|| fields.first().map(|fa| &fa.value)) + { + let replace_expr = package + .exprs + .get(replacement) + .expect("replacement should exist"); + let replace_kind = replace_expr.kind.clone(); + let replace_ty = replace_expr.ty.clone(); + let expr_mut = package.exprs.get_mut(expr_id).expect("expr should exist"); + expr_mut.kind = replace_kind; + expr_mut.ty = resolve_ty(udt_cache, &replace_ty); + } else { + // Defensive fallback: single-field UDT with no overrides after + // scalar erasure. The frontend should simplify copy-update + // expressions with zero overrides before they reach this point, + // making this path unreachable in practice. The fallback + // correctly propagates the copy source if it is ever hit. 
+ debug_assert!( + false, + "copy-update with no field overrides on a scalar-erased single-field UDT \ + should be simplified before reaching lower_copy_update_struct" + ); + let copy_expr = package + .exprs + .get(copy_id) + .expect("copy source should exist"); + let copy_kind = copy_expr.kind.clone(); + let expr_mut = package.exprs.get_mut(expr_id).expect("expr should exist"); + expr_mut.kind = copy_kind; + } + } +} + +/// Lowers `UpdateField` and `AssignField` with `Field::Path` for a single +/// expression, replacing the expression kind in place. +/// +/// # Before +/// ```text +/// UpdateField(record, Field::Path([1]), new_val) // record w/ field 1 updated +/// AssignField(record, Field::Path([1]), new_val) // assign field 1 +/// ``` +/// # After +/// ```text +/// Tuple([Field(record, Path([0])), new_val]) // lowered tuple +/// Assign(record, Tuple([Field(record, Path([0])), new_val])) +/// ``` +/// +/// # Mutations +/// - Rewrites `expr_id`'s `ExprKind` in place. +/// - Allocates field-extraction and update `Expr` nodes through `cloner`. +fn lower_field_updates( + package: &mut Package, + cloner: &mut FirCloner, + udt_cache: &UdtCache, + expr_id: ExprId, + kind: &ExprKind, + span: Span, +) { + // Lower UpdateField(record, Field::Path(path), replace) into a + // tuple construction that extracts all non-updated fields from the + // record and inserts the replacement at the correct position. + if let ExprKind::UpdateField(record_id, Field::Path(path), replace_id) = kind { + // The record expression may not yet have its type resolved + // (FIR parent IDs are allocated before children, so record_id + // can be > expr_id). Resolve the type explicitly. 
+ let record_raw_ty = &package + .exprs + .get(*record_id) + .expect("record should exist") + .ty; + let record_ty = resolve_ty(udt_cache, record_raw_ty); + let lowered = lower_update_field( + package, + cloner, + *record_id, + &path.indices, + *replace_id, + &record_ty, + span, + ); + let expr_mut = package.exprs.get_mut(expr_id).expect("expr should exist"); + expr_mut.kind = lowered; + } + + // Lower AssignField(record, Field::Path(path), value) into + // Assign(record, ). + if let ExprKind::AssignField(record_id, Field::Path(path), value_id) = kind { + let record_raw_ty = &package + .exprs + .get(*record_id) + .expect("record should exist") + .ty; + let record_ty = resolve_ty(udt_cache, record_raw_ty); + let lowered = lower_update_field( + package, + cloner, + *record_id, + &path.indices, + *value_id, + &record_ty, + span, + ); + let update_expr_id = cloner.alloc_expr(); + package.exprs.insert( + update_expr_id, + Expr { + id: update_expr_id, + span, + ty: record_ty, + kind: lowered, + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + let expr_mut = package.exprs.get_mut(expr_id).expect("expr should exist"); + expr_mut.kind = ExprKind::Assign(*record_id, update_expr_id); + } +} + +/// Lowers `Field(record_id, Field::Path(_))` read expressions on scalar-erased +/// types, replacing the expression kind in place when the record type is not +/// a tuple. +/// +/// For scalar-erased single-field newtypes, the record type after erasure is +/// a primitive or other scalar type (e.g., `Prim(Int)`) rather than a tuple. +/// In this case, a field access like `w::x` is semantically an identity access +/// on the scalar value and should be replaced with a direct reference to the +/// record. This maintains the `PostUdtErase` invariant that `Field::Path` only +/// appears on `Ty::Tuple` records. 
+/// +/// For example: +/// - `newtype Wrapper = (x: Int); function Extract(w: Wrapper) : Int { w::x }` +/// - After UDT erasure: `w: Prim(Int)`, but `Field(w, Path([]))` remains +/// - This function replaces `Field(w, Path([]))` with `w` directly. +fn lower_scalar_field_read( + package: &mut Package, + udt_cache: &UdtCache, + expr_id: ExprId, + kind: &ExprKind, +) { + if let ExprKind::Field(record_id, Field::Path(_)) = kind { + let record_raw_ty = &package + .exprs + .get(*record_id) + .expect("record should exist") + .ty; + let record_ty = resolve_ty(udt_cache, record_raw_ty); + + // If the record type is not a tuple, this is a scalar-erased + // single-field newtype. Replace the field read with the record. + if !matches!(&record_ty, Ty::Tuple(_)) { + let record_expr = package.exprs.get(*record_id).expect("record should exist"); + let record_kind = record_expr.kind.clone(); + let record_ty_resolved = resolve_ty(udt_cache, &record_expr.ty); + let expr_mut = package.exprs.get_mut(expr_id).expect("expr should exist"); + expr_mut.kind = record_kind; + expr_mut.ty = record_ty_resolved; + } + } +} + +/// Builds a `(PackageId, LocalItemId) → pure Ty` cache for every UDT +/// definition in the package store so [`resolve_ty`] can perform O(1) +/// cross-package lookups. +fn build_udt_cache(store: &PackageStore) -> UdtCache { + let mut cache = FxHashMap::default(); + for (pkg_id, package) in store { + for (item_id, item) in &package.items { + if let ItemKind::Ty(_, udt) = &item.kind { + cache.insert((pkg_id, item_id), udt.get_pure_ty()); + } + } + } + cache +} + +/// Lowers `UpdateField(record, Field::Path(indices), replace)` into a tuple +/// construction that extracts all non-updated elements from `record` and +/// inserts `replace` at the position indicated by `indices`. +/// +/// For multi-level paths (`[i, j, ...]`), the lowering is recursive: the +/// element at index `i` is itself updated by lowering `[j, ...]` on the +/// extracted sub-record. 
+/// +/// For single-field UDTs (where the post-erasure record type is scalar, not +/// a tuple), the entire record is replaced by `replace`, and the result is +/// simply the replacement expression's kind. +fn lower_update_field( + package: &mut Package, + cloner: &mut FirCloner, + record_id: ExprId, + indices: &[usize], + replace_id: ExprId, + record_ty: &Ty, + span: Span, +) -> ExprKind { + match (indices, record_ty) { + // Single-level path on a tuple: build a new tuple with the + // replacement at `idx` and field extractions everywhere else. + (&[idx], Ty::Tuple(elems)) => { + debug_assert!( + idx < elems.len(), + "field path indices are guaranteed valid by frontend and prior-pass type checking" + ); + build_updated_tuple(package, cloner, record_id, idx, replace_id, elems, span) + } + + // Multi-level path on a tuple: recursively lower the inner update + // on the sub-record at index `idx`. + (&[idx, ref rest @ ..], Ty::Tuple(elems)) => { + debug_assert!( + idx < elems.len(), + "field path indices are guaranteed valid by frontend and prior-pass type checking" + ); + // Extract the sub-record at position idx. + let sub_id = alloc_field_expr(package, cloner, record_id, idx, &elems[idx], span); + + // Recursively lower the inner path on the sub-record. + let inner_kind = + lower_update_field(package, cloner, sub_id, rest, replace_id, &elems[idx], span); + + // Wrap the recursive result in a new expression. + let inner_result_id = cloner.alloc_expr(); + package.exprs.insert( + inner_result_id, + Expr { + id: inner_result_id, + span, + ty: elems[idx].clone(), + kind: inner_kind, + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + + // Build the outer tuple with the recursively updated element. + build_updated_tuple( + package, + cloner, + record_id, + idx, + inner_result_id, + elems, + span, + ) + } + + // Empty path (single-field UDT whose wrapping was erased) or + // single-level path on a non-tuple scalar type: the entire record + // value is replaced. 
+ ([] | &[_], _) => { + let replace_expr = package.exprs.get(replace_id).expect("replace should exist"); + replace_expr.kind.clone() + } + + // Fallback: retained as a guarded branch so invariants violations + // surface as a well-formed (but unlowered) UpdateField rather + // than a panic. Under a correct + // [`crate::invariants::InvariantLevel::PostUdtErase`] the path + // shape and record type will always match one of the arms above, + // making this arm unreachable. + _ => ExprKind::UpdateField( + record_id, + Field::Path(FieldPath { + indices: indices.to_vec(), + }), + replace_id, + ), + } +} + +/// Builds `ExprKind::Tuple(fields)` where `fields[update_idx]` is +/// `replace_id` and every other position is a freshly allocated +/// `ExprKind::Field(record_id, Path([j]))`. +/// +/// # Before +/// ```text +/// (no expression) +/// ``` +/// # After +/// ```text +/// Tuple([Field(record, Path([0])), replace, Field(record, Path([2]))]) +/// ``` +/// +/// # Mutations +/// - Allocates `Field` `Expr` nodes through `cloner` for non-updated positions. +fn build_updated_tuple( + package: &mut Package, + cloner: &mut FirCloner, + record_id: ExprId, + update_idx: usize, + replace_id: ExprId, + elems: &[Ty], + span: Span, +) -> ExprKind { + debug_assert!( + update_idx < elems.len(), + "field path indices are guaranteed valid by frontend and prior-pass type checking" + ); + let mut field_ids = Vec::with_capacity(elems.len()); + for (j, elem_ty) in elems.iter().enumerate() { + if j == update_idx { + field_ids.push(replace_id); + } else { + let field_id = alloc_field_expr(package, cloner, record_id, j, elem_ty, span); + field_ids.push(field_id); + } + } + ExprKind::Tuple(field_ids) +} + +/// Allocates a new `Expr` with `ExprKind::Field(record_id, Path([index]))`. +/// +/// # Mutations +/// - Inserts one `Expr` node through `cloner`. 
+fn alloc_field_expr( + package: &mut Package, + cloner: &mut FirCloner, + record_id: ExprId, + index: usize, + ty: &Ty, + span: Span, +) -> ExprId { + let field_id = cloner.alloc_expr(); + package.exprs.insert( + field_id, + Expr { + id: field_id, + span, + ty: ty.clone(), + kind: ExprKind::Field( + record_id, + Field::Path(FieldPath { + indices: vec![index], + }), + ), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + field_id +} + +/// Recursively resolves `Ty::Udt` references to their pure types. +/// +/// Uses the pre-built [`UdtCache`] for O(1) cross-package lookups and +/// recursively resolves embedded tuple, array, and arrow types so the +/// returned `Ty` is fully UDT-free. +fn resolve_ty(cache: &UdtCache, ty: &Ty) -> Ty { + match ty { + Ty::Udt(Res::Item(item_id)) => { + let key = (item_id.package, item_id.item); + if let Some(pure) = cache.get(&key) { + // The pure type itself may contain Ty::Udt (nested UDTs), + // so recurse. + resolve_ty(cache, pure) + } else { + ty.clone() + } + } + Ty::Array(elem) => { + let resolved = resolve_ty(cache, elem); + Ty::Array(Box::new(resolved)) + } + Ty::Tuple(elems) => { + let resolved: Vec = elems.iter().map(|e| resolve_ty(cache, e)).collect(); + Ty::Tuple(resolved) + } + Ty::Arrow(arrow) => { + let resolved_input = resolve_ty(cache, &arrow.input); + let resolved_output = resolve_ty(cache, &arrow.output); + Ty::Arrow(Box::new(Arrow { + kind: arrow.kind, + input: Box::new(resolved_input), + output: Box::new(resolved_output), + functors: arrow.functors, + })) + } + // Primitives, Param, Infer, Err — no UDT references to resolve. 
+ _ => ty.clone(), + } +} diff --git a/source/compiler/qsc_fir_transforms/src/udt_erase/semantic_equivalence_tests.rs b/source/compiler/qsc_fir_transforms/src/udt_erase/semantic_equivalence_tests.rs new file mode 100644 index 0000000000..3305d0fc6b --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/udt_erase/semantic_equivalence_tests.rs @@ -0,0 +1,141 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use indoc::formatdoc; +use indoc::indoc; +use proptest::prelude::*; + +#[test] +fn udt_construction_and_field_access_preserves_semantics() { + crate::test_utils::check_semantic_equivalence(indoc! {r#" + namespace Test { + struct Pair { X : Int, Y : Int } + + @EntryPoint() + function Main() : Int { + let p = new Pair { X = 5, Y = 3 }; + p.X - p.Y + } + } + "#}); +} + +#[test] +fn udt_returned_from_function_preserves_semantics() { + crate::test_utils::check_semantic_equivalence(indoc! {r#" + namespace Test { + struct Wrapper { Value : Int } + + function MakeWrapper(v : Int) : Wrapper { + new Wrapper { Value = v } + } + + @EntryPoint() + function Main() : Int { + let w = MakeWrapper(42); + w.Value + } + } + "#}); +} + +#[test] +fn nested_udt_preserves_semantics() { + crate::test_utils::check_semantic_equivalence(indoc! {r#" + namespace Test { + struct Inner { A : Int, B : Int } + struct Outer { First : Inner, Second : Int } + + @EntryPoint() + function Main() : Int { + let inner = new Inner { A = 10, B = 20 }; + let outer = new Outer { First = inner, Second = 30 }; + outer.First.A + outer.First.B + outer.Second + } + } + "#}); +} + +#[test] +fn pretty_print_after_udt_erase_is_non_empty() { + let source = indoc! 
{r#" + namespace Test { + struct Pair { X : Int, Y : Int } + + @EntryPoint() + function Main() : Int { + let p = new Pair { X = 1, Y = 2 }; + p.X + p.Y + } + } + "#}; + let (store, pkg_id) = + crate::test_utils::compile_and_run_pipeline_to(source, crate::PipelineStage::UdtErase); + let rendered = crate::pretty::write_package_qsharp(&store, pkg_id); + // After UDT erasure the rendered Q# replaces struct construction with + // tuple literals and uses `::Item` field access. Verify non-empty. + assert!( + !rendered.is_empty(), + "pretty-printed Q# after UDT erasure should not be empty" + ); +} + +fn udt_erasure_pattern() -> impl Strategy { + (1..=4usize, prop::bool::ANY).prop_map(|(field_count, use_copy_update)| { + let fields = (0..field_count) + .map(|field_index| format!("F{field_index} : Int")) + .collect::>() + .join(", "); + let assignments = (0..field_count) + .map(|field_index| format!("F{field_index} = {field_index}")) + .collect::>() + .join(", "); + + if use_copy_update { + let updated_field = field_count - 1; + let result = (0..field_count) + .map(|field_index| format!("updated.F{field_index}")) + .collect::>() + .join(" + "); + + formatdoc! {r#" + namespace Test {{ + struct Generated {{ {fields} }} + + @EntryPoint() + function Main() : Int {{ + let record = new Generated {{ {assignments} }}; + let updated = new Generated {{ ...record, F{updated_field} = 99 }}; + {result} + }} + }} + "#} + } else { + let result = (0..field_count) + .map(|field_index| format!("record.F{field_index}")) + .collect::>() + .join(" + "); + + formatdoc! {r#" + namespace Test {{ + struct Generated {{ {fields} }} + + @EntryPoint() + function Main() : Int {{ + let record = new Generated {{ {assignments} }}; + {result} + }} + }} + "#} + } + }) +} + +proptest! 
{ + #![proptest_config(ProptestConfig::with_cases(50))] + + #[test] + fn udt_erasure_preserves_semantics(source in udt_erasure_pattern()) { + crate::test_utils::check_semantic_equivalence(&source); + } +} diff --git a/source/compiler/qsc_fir_transforms/src/udt_erase/tests.rs b/source/compiler/qsc_fir_transforms/src/udt_erase/tests.rs new file mode 100644 index 0000000000..6f27459e9c --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/udt_erase/tests.rs @@ -0,0 +1,1585 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use expect_test::{Expect, expect}; +use indoc::indoc; + +use super::*; +use qsc_data_structures::index_map::IndexMap; +use qsc_data_structures::span::Span; +use qsc_fir::fir::{ + Block, CallableDecl, CallableImpl, CallableKind, ExecGraph, Expr, ExprKind, Field, FieldAssign, + FieldPath, Ident, Item, ItemId, LocalVarId, NodeId, PackageLookup, Pat, PatKind, SpecDecl, + SpecImpl, Stmt, StmtId, StmtKind, Visibility, +}; +use qsc_fir::ty::{FunctorSet, FunctorSetValue, Prim, Udt, UdtDef, UdtDefKind, UdtField}; +use rustc_hash::FxHashMap; +use std::rc::Rc; + +use crate::EMPTY_EXEC_RANGE; + +fn default_span() -> Span { + Span::default() +} + +/// Creates a minimal UDT type item (like `newtype Pair = (Int, Double)`). 
+fn make_udt_item(item_id: LocalItemId, fields: Vec<(Option>, Ty)>) -> Item { + let def = if fields.len() == 1 { + UdtDef { + span: default_span(), + kind: UdtDefKind::Field(UdtField { + name_span: None, + name: fields[0].0.clone(), + ty: fields[0].1.clone(), + }), + } + } else { + UdtDef { + span: default_span(), + kind: UdtDefKind::Tuple( + fields + .into_iter() + .map(|(name, ty)| UdtDef { + span: default_span(), + kind: UdtDefKind::Field(UdtField { + name_span: None, + name, + ty, + }), + }) + .collect(), + ), + } + }; + let udt = Udt { + span: default_span(), + name: Rc::from("TestUdt"), + definition: def, + }; + Item { + id: item_id, + span: default_span(), + parent: None, + doc: Rc::from(""), + attrs: vec![], + visibility: Visibility::Public, + kind: ItemKind::Ty( + Ident { + id: LocalVarId::default(), + span: default_span(), + name: Rc::from("TestUdt"), + }, + udt, + ), + } +} + +/// Creates a store with one package containing the given items. +fn make_store_with_items(items: Vec) -> (PackageStore, PackageId) { + let pkg_id = PackageId::from(0usize); + let mut store = PackageStore::new(); + let mut package = Package { + items: IndexMap::new(), + entry: None, + entry_exec_graph: ExecGraph::default(), + blocks: IndexMap::new(), + exprs: IndexMap::new(), + pats: IndexMap::new(), + stmts: IndexMap::new(), + }; + for item in items { + package.items.insert(item.id, item); + } + store.insert(pkg_id, package); + (store, pkg_id) +} + +fn make_ident(name: &str) -> Ident { + Ident { + id: LocalVarId::default(), + span: default_span(), + name: Rc::from(name), + } +} + +fn make_empty_package() -> Package { + Package { + items: IndexMap::new(), + entry: None, + entry_exec_graph: ExecGraph::default(), + blocks: IndexMap::new(), + exprs: IndexMap::new(), + pats: IndexMap::new(), + stmts: IndexMap::new(), + } +} + +fn insert_unit_pat(package: &mut Package, pat_id: PatId) { + package.pats.insert( + pat_id, + Pat { + id: pat_id, + span: default_span(), + ty: Ty::UNIT, + kind: 
PatKind::Tuple(vec![]), + }, + ); +} + +fn insert_unit_expr(package: &mut Package, expr_id: ExprId) { + package.exprs.insert( + expr_id, + Expr { + id: expr_id, + span: default_span(), + ty: Ty::UNIT, + kind: ExprKind::Tuple(vec![]), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); +} + +fn insert_bool_lit(package: &mut Package, expr_id: ExprId, value: bool) { + package.exprs.insert( + expr_id, + Expr { + id: expr_id, + span: default_span(), + ty: Ty::Prim(Prim::Bool), + kind: ExprKind::Lit(qsc_fir::fir::Lit::Bool(value)), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); +} + +fn insert_struct_callable_package( + store: &mut PackageStore, + package_id: PackageId, + callable_name: &str, + bool_value: bool, +) -> (LocalItemId, LocalItemId, ExprId) { + let udt_item_id = LocalItemId::from(0usize); + let callable_item_id = LocalItemId::from(1usize); + let input_pat_id = PatId::from(0usize); + let value_expr_id = ExprId::from(0usize); + let struct_expr_id = ExprId::from(1usize); + let stmt_id = StmtId::from(0usize); + let block_id = BlockId::from(0usize); + + let mut package = make_empty_package(); + insert_unit_pat(&mut package, input_pat_id); + insert_bool_lit(&mut package, value_expr_id, bool_value); + + let udt_res = Res::Item(ItemId { + package: package_id, + item: udt_item_id, + }); + let udt_ty = Ty::Udt(udt_res); + + package.exprs.insert( + struct_expr_id, + Expr { + id: struct_expr_id, + span: default_span(), + ty: udt_ty.clone(), + kind: ExprKind::Struct( + udt_res, + None, + vec![FieldAssign { + id: NodeId::from(0usize), + span: default_span(), + field: Field::Path(FieldPath { indices: vec![0] }), + value: value_expr_id, + }], + ), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + package.stmts.insert( + stmt_id, + Stmt { + id: stmt_id, + span: default_span(), + kind: StmtKind::Expr(struct_expr_id), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + package.blocks.insert( + block_id, + Block { + id: block_id, + span: default_span(), + ty: udt_ty.clone(), + stmts: 
vec![stmt_id], + }, + ); + package.items.insert( + udt_item_id, + make_udt_item( + udt_item_id, + vec![(Some(Rc::from("Value")), Ty::Prim(Prim::Bool))], + ), + ); + package.items.insert( + callable_item_id, + Item { + id: callable_item_id, + span: default_span(), + parent: None, + doc: Rc::from(""), + attrs: vec![], + visibility: Visibility::Public, + kind: ItemKind::Callable(Box::new(CallableDecl { + id: NodeId::from(1usize), + span: default_span(), + kind: CallableKind::Function, + name: make_ident(callable_name), + generics: vec![], + input: input_pat_id, + output: udt_ty, + functors: FunctorSetValue::Empty, + implementation: CallableImpl::Spec(SpecImpl { + body: SpecDecl { + id: NodeId::from(2usize), + span: default_span(), + block: block_id, + input: Some(input_pat_id), + exec_graph: ExecGraph::default(), + }, + adj: None, + ctl: None, + ctl_adj: None, + }), + attrs: vec![], + })), + }, + ); + store.insert(package_id, package); + + (udt_item_id, callable_item_id, struct_expr_id) +} + +fn make_entry_package_for_external_callable( + callee_package_id: PackageId, + callee_item_id: LocalItemId, + callee_udt_item_id: LocalItemId, +) -> Package { + let mut package = make_empty_package(); + let unit_expr_id = ExprId::from(0usize); + let callee_expr_id = ExprId::from(1usize); + let call_expr_id = ExprId::from(2usize); + + let output_ty = Ty::Udt(Res::Item(ItemId { + package: callee_package_id, + item: callee_udt_item_id, + })); + + insert_unit_expr(&mut package, unit_expr_id); + package.exprs.insert( + callee_expr_id, + Expr { + id: callee_expr_id, + span: default_span(), + ty: Ty::Arrow(Box::new(Arrow { + kind: CallableKind::Function, + input: Box::new(Ty::UNIT), + output: Box::new(output_ty.clone()), + functors: FunctorSet::Value(FunctorSetValue::Empty), + })), + kind: ExprKind::Var( + Res::Item(ItemId { + package: callee_package_id, + item: callee_item_id, + }), + vec![], + ), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + package.exprs.insert( + call_expr_id, + 
Expr { + id: call_expr_id, + span: default_span(), + ty: output_ty, + kind: ExprKind::Call(callee_expr_id, unit_expr_id), + exec_graph_range: EMPTY_EXEC_RANGE, + }, + ); + package.entry = Some(call_expr_id); + + package +} + +#[test] +fn resolve_ty_replaces_udt_with_pure_type() { + let item_id = LocalItemId::from(0usize); + let udt_item = make_udt_item( + item_id, + vec![ + (Some(Rc::from("fst")), Ty::Prim(Prim::Int)), + (Some(Rc::from("snd")), Ty::Prim(Prim::Double)), + ], + ); + let (store, pkg_id) = make_store_with_items(vec![udt_item]); + let cache = build_udt_cache(&store); + + let udt_ty = Ty::Udt(Res::Item(ItemId { + package: pkg_id, + item: item_id, + })); + let resolved = resolve_ty(&cache, &udt_ty); + assert_eq!( + resolved, + Ty::Tuple(vec![Ty::Prim(Prim::Int), Ty::Prim(Prim::Double)]) + ); +} + +#[test] +fn resolve_ty_single_field_udt_unwraps() { + let item_id = LocalItemId::from(0usize); + let udt_item = make_udt_item(item_id, vec![(Some(Rc::from("val")), Ty::Prim(Prim::Int))]); + let (store, pkg_id) = make_store_with_items(vec![udt_item]); + let cache = build_udt_cache(&store); + + let udt_ty = Ty::Udt(Res::Item(ItemId { + package: pkg_id, + item: item_id, + })); + let resolved = resolve_ty(&cache, &udt_ty); + assert_eq!(resolved, Ty::Prim(Prim::Int)); +} + +#[test] +fn resolve_ty_handles_nested_udt() { + let inner_id = LocalItemId::from(0usize); + let outer_id = LocalItemId::from(1usize); + let pkg_id = PackageId::from(0usize); + + let inner_item = make_udt_item( + inner_id, + vec![ + (Some(Rc::from("a")), Ty::Prim(Prim::Int)), + (Some(Rc::from("b")), Ty::Prim(Prim::Int)), + ], + ); + // Outer UDT has one field of type Inner UDT + one Int. 
+ let outer_fields = vec![ + ( + Some(Rc::from("inner")), + Ty::Udt(Res::Item(ItemId { + package: pkg_id, + item: inner_id, + })), + ), + (Some(Rc::from("extra")), Ty::Prim(Prim::Bool)), + ]; + let outer_item = make_udt_item(outer_id, outer_fields); + + let (store, _) = make_store_with_items(vec![inner_item, outer_item]); + let cache = build_udt_cache(&store); + + let outer_ty = Ty::Udt(Res::Item(ItemId { + package: pkg_id, + item: outer_id, + })); + let resolved = resolve_ty(&cache, &outer_ty); + assert_eq!( + resolved, + Ty::Tuple(vec![ + Ty::Tuple(vec![Ty::Prim(Prim::Int), Ty::Prim(Prim::Int)]), + Ty::Prim(Prim::Bool), + ]) + ); +} + +#[test] +fn resolve_ty_in_array() { + let item_id = LocalItemId::from(0usize); + let udt_item = make_udt_item( + item_id, + vec![(None, Ty::Prim(Prim::Int)), (None, Ty::Prim(Prim::Int))], + ); + let (store, pkg_id) = make_store_with_items(vec![udt_item]); + let cache = build_udt_cache(&store); + + let arr_ty = Ty::Array(Box::new(Ty::Udt(Res::Item(ItemId { + package: pkg_id, + item: item_id, + })))); + let resolved = resolve_ty(&cache, &arr_ty); + assert_eq!( + resolved, + Ty::Array(Box::new(Ty::Tuple(vec![ + Ty::Prim(Prim::Int), + Ty::Prim(Prim::Int) + ]))) + ); +} + +#[test] +fn resolve_ty_in_arrow() { + let item_id = LocalItemId::from(0usize); + let udt_item = make_udt_item( + item_id, + vec![(None, Ty::Prim(Prim::Int)), (None, Ty::Prim(Prim::Double))], + ); + let (store, pkg_id) = make_store_with_items(vec![udt_item]); + let cache = build_udt_cache(&store); + + let udt_ty = Ty::Udt(Res::Item(ItemId { + package: pkg_id, + item: item_id, + })); + let arrow_ty = Ty::Arrow(Box::new(Arrow { + kind: CallableKind::Operation, + input: Box::new(udt_ty), + output: Box::new(Ty::UNIT), + functors: FunctorSet::Value(FunctorSetValue::Empty), + })); + let resolved = resolve_ty(&cache, &arrow_ty); + let expected_input = Ty::Tuple(vec![Ty::Prim(Prim::Int), Ty::Prim(Prim::Double)]); + if let Ty::Arrow(a) = &resolved { + assert_eq!(*a.input, 
expected_input); + assert_eq!(*a.output, Ty::UNIT); + } else { + panic!("expected Arrow type"); + } +} + +/// Compiles Q# through defunctionalization, runs UDT erasure, and +/// returns a snapshot of callable signatures in the user package. +fn extract_types_after_erasure(source: &str) -> String { + use crate::test_utils::{PipelineStage, compile_and_run_pipeline_to}; + + let (mut store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::Defunc); + let mut assigner = qsc_fir::assigner::Assigner::from_package(store.get(pkg_id)); + erase_udts(&mut store, pkg_id, &mut assigner); + + let package = store.get(pkg_id); + let reachable = crate::reachability::collect_reachable_from_entry(&store, pkg_id); + let mut lines: Vec = Vec::new(); + for store_id in &reachable { + if store_id.package != pkg_id { + continue; + } + let item = package.get_item(store_id.item); + if let ItemKind::Callable(decl) = &item.kind { + let pat = package.get_pat(decl.input); + lines.push(format!( + "{}: input={}, output={}", + decl.name.name, pat.ty, decl.output + )); + } + } + lines.sort(); + lines.join("\n") +} + +fn check_erasure(source: &str, expect: &Expect) { + expect.assert_eq(&extract_types_after_erasure(source)); +} + +fn find_callable_body_block(package: &Package, callable_name: &str) -> BlockId { + for item in package.items.values() { + if let ItemKind::Callable(decl) = &item.kind + && decl.name.name.as_ref() == callable_name + { + return match &decl.implementation { + CallableImpl::Spec(spec_impl) => spec_impl.body.block, + CallableImpl::SimulatableIntrinsic(spec) => spec.block, + CallableImpl::Intrinsic => continue, + }; + } + } + + panic!("callable '{callable_name}' not found"); +} + +fn local_names(package: &Package) -> FxHashMap { + package + .pats + .values() + .filter_map(|pat| match &pat.kind { + PatKind::Bind(ident) => Some((ident.id, ident.name.to_string())), + PatKind::Tuple(_) | PatKind::Discard => None, + }) + .collect() +} + +fn local_name(local_names: &FxHashMap, 
local_id: LocalVarId) -> String { + local_names + .get(&local_id) + .cloned() + .unwrap_or_else(|| format!("<{local_id:?}>")) +} + +fn format_pat_name(package: &Package, pat_id: PatId) -> String { + let pat = package.get_pat(pat_id); + match &pat.kind { + PatKind::Bind(ident) => ident.name.to_string(), + PatKind::Tuple(sub_pats) => format!( + "({})", + sub_pats + .iter() + .map(|&sub_pat_id| format_pat_name(package, sub_pat_id)) + .collect::>() + .join(", ") + ), + PatKind::Discard => "_".to_string(), + } +} + +fn describe_expr( + package: &Package, + expr_id: ExprId, + local_names: &FxHashMap, +) -> String { + let expr = package.get_expr(expr_id); + match &expr.kind { + ExprKind::Assign(lhs, rhs) => format!( + "Assign({}, {})", + describe_expr(package, *lhs, local_names), + describe_expr(package, *rhs, local_names) + ), + ExprKind::Field(target, field) => format!( + "Field({}, {field})", + describe_expr(package, *target, local_names) + ), + ExprKind::Lit(lit) => format!("Lit({lit:?})"), + ExprKind::Tuple(items) => format!( + "Tuple({})", + items + .iter() + .map(|&item_id| describe_expr(package, item_id, local_names)) + .collect::>() + .join(", ") + ), + ExprKind::Var(Res::Local(local_id), _) => { + format!("Var({})", local_name(local_names, *local_id)) + } + ExprKind::Var(res, _) => format!("Var({res})"), + _ => crate::test_utils::expr_kind_short(package, expr_id), + } +} + +fn callable_local_summaries_after_erasure(source: &str, callable_name: &str) -> String { + use crate::test_utils::{PipelineStage, compile_and_run_pipeline_to}; + + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::UdtErase); + let package = store.get(pkg_id); + let block = package.get_block(find_callable_body_block(package, callable_name)); + let local_names = local_names(package); + + block + .stmts + .iter() + .filter_map(|&stmt_id| { + let stmt = package.get_stmt(stmt_id); + match &stmt.kind { + StmtKind::Local(mutability, pat_id, init_expr_id) => Some(format!( + 
"{mutability:?} {} = {}", + format_pat_name(package, *pat_id), + describe_expr(package, *init_expr_id, &local_names) + )), + _ => None, + } + }) + .collect::>() + .join("\n") +} + +fn callable_body_summary_after_erasure(source: &str, callable_name: &str) -> String { + use crate::test_utils::{PipelineStage, compile_and_run_pipeline_to}; + + let (store, pkg_id) = compile_and_run_pipeline_to(source, PipelineStage::UdtErase); + let package = store.get(pkg_id); + let block = package.get_block(find_callable_body_block(package, callable_name)); + let local_names = local_names(package); + + block + .stmts + .iter() + .enumerate() + .map(|(index, &stmt_id)| { + let stmt = package.get_stmt(stmt_id); + let summary = match &stmt.kind { + StmtKind::Expr(expr_id) => { + format!("Expr {}", describe_expr(package, *expr_id, &local_names)) + } + StmtKind::Semi(expr_id) => { + format!("Semi {}", describe_expr(package, *expr_id, &local_names)) + } + StmtKind::Local(mutability, pat_id, init_expr_id) => format!( + "Local {mutability:?} {} = {}", + format_pat_name(package, *pat_id), + describe_expr(package, *init_expr_id, &local_names) + ), + StmtKind::Item(local_item_id) => format!("Item {local_item_id}"), + }; + + format!("[{index}] {summary}") + }) + .collect::>() + .join("\n") +} + +fn main_local_summaries_after_erasure(source: &str) -> String { + callable_local_summaries_after_erasure(source, "Main") +} + +fn main_body_summary_after_erasure(source: &str) -> String { + callable_body_summary_after_erasure(source, "Main") +} + +fn check_callable_body_summary_after_erasure(source: &str, callable_name: &str, expect: &Expect) { + expect.assert_eq(&callable_body_summary_after_erasure(source, callable_name)); +} + +fn check_main_local_summaries_after_erasure(source: &str, expect: &Expect) { + expect.assert_eq(&main_local_summaries_after_erasure(source)); +} + +fn check_main_body_summary_after_erasure(source: &str, expect: &Expect) { + 
expect.assert_eq(&main_body_summary_after_erasure(source)); +} + +#[test] +fn simple_newtype_erased_to_inner_type() { + check_erasure( + indoc! {" + namespace Test { + newtype Wrapper = Int; + @EntryPoint() + function Main() : Unit { + let w = Wrapper(42); + } + } + "}, + &expect![[r#" + Main: input=Unit, output=Unit"#]], + ); +} + +#[test] +fn tuple_udt_erased_to_tuple() { + check_erasure( + indoc! {" + namespace Test { + newtype Pair = (Fst: Int, Snd: Double); + function MakePair() : (Int, Double) { + let p = Pair(1, 2.0); + (p::Fst, p::Snd) + } + @EntryPoint() + function Main() : Unit { + let _ = MakePair(); + } + } + "}, + &expect![[r#" + Main: input=Unit, output=Unit + MakePair: input=Unit, output=(Int, Double)"#]], + ); +} + +#[test] +fn nested_udt_erased_to_nested_tuple() { + check_erasure( + indoc! {" + namespace Test { + newtype Inner = (A: Int, B: Int); + newtype Outer = (First: Inner, Extra: Bool); + function MakeOuter() : Outer { + let i = Inner(1, 2); + Outer(i, true) + } + @EntryPoint() + function Main() : Unit { + let _ = MakeOuter(); + } + } + "}, + &expect![[r#" + Main: input=Unit, output=Unit + MakeOuter: input=Unit, output=((Int, Int), Bool)"#]], + ); +} + +/// Verifies that `p w/ Fst <- 42` on a two-field UDT is lowered to a +/// tuple construction after UDT erasure. The `PostUdtErase` invariant +/// check (run inside the pipeline) asserts that no +/// `UpdateField(_, Field::Path(_), _)` survives. +#[test] +fn udt_update_field_simple() { + check_main_local_summaries_after_erasure( + indoc! 
{" + namespace Test { + newtype Pair = (Fst: Int, Snd: Double); + @EntryPoint() + function Main() : Unit { + let p = Pair(1, 2.0); + let p2 = p w/ Fst <- 42; + } + } + "}, + &expect![[r#" + Immutable p = Tuple(Lit(Int(1)), Lit(Double(2.0))) + Immutable p2 = Tuple(Lit(Int(42)), Field(Var(p), Path([1])))"#]], + ); +} + +/// Verifies multi-level path lowering: `f w/ b <- 3.14` on a UDT with +/// nested anonymous tuple `(a: Int, (b: Double, c: Bool))` produces +/// field path `[1, 0]` which must be recursively lowered. +#[test] +fn udt_update_field_nested() { + check_main_local_summaries_after_erasure( + indoc! {" + namespace Test { + newtype Foo = (a: Int, (b: Double, c: Bool)); + @EntryPoint() + function Main() : Unit { + let f = Foo(1, (2.0, true)); + let f2 = f w/ b <- 3.14; + } + } + "}, + &expect![[r#" + Immutable f = Tuple(Lit(Int(1)), Tuple(Lit(Double(2.0)), Lit(Bool(true)))) + Immutable f2 = Tuple(Field(Var(f), Path([0])), Tuple(Lit(Double(3.14)), Field(Field(Var(f), Path([1])), Path([1]))))"#]], + ); +} + +/// Verifies that `w w/ val <- 42` on a single-field UDT (where the +/// pure type is scalar, not a tuple) is lowered to the replacement +/// value directly. +#[test] +fn udt_update_field_single_field() { + check_main_local_summaries_after_erasure( + indoc! {" + namespace Test { + newtype Wrapper = (val: Int); + @EntryPoint() + function Main() : Unit { + let w = Wrapper(99); + let w2 = w w/ val <- 42; + } + } + "}, + &expect![[r#" + Immutable w = Lit(Int(99)) + Immutable w2 = Lit(Int(42))"#]], + ); +} + +/// Verifies that `set p w/= Fst <- 42` (`AssignField`) is lowered to +/// `Assign(p, Tuple(...))` after UDT erasure. +#[test] +fn udt_assign_field() { + check_main_body_summary_after_erasure( + indoc! 
{" + namespace Test { + newtype Pair = (Fst: Int, Snd: Double); + @EntryPoint() + function Main() : Unit { + mutable p = Pair(1, 2.0); + p w/= Fst <- 42; + } + } + "}, + &expect![[r#" + [0] Local Mutable p = Tuple(Lit(Int(1)), Lit(Double(2.0))) + [1] Semi Assign(Var(p), Tuple(Lit(Int(42)), Field(Var(p), Path([1]))))"#]], + ); +} + +/// Verifies that two successive `w/` updates are each independently +/// lowered into tuple constructions. +#[test] +fn udt_chained_update() { + check_main_local_summaries_after_erasure( + indoc! {" + namespace Test { + newtype Pair = (Fst: Int, Snd: Double); + @EntryPoint() + function Main() : Unit { + let p = Pair(1, 2.0); + let p2 = p w/ Fst <- 42; + let p3 = p2 w/ Snd <- 3.14; + } + } + "}, + &expect![[r#" + Immutable p = Tuple(Lit(Int(1)), Lit(Double(2.0))) + Immutable p2 = Tuple(Lit(Int(42)), Field(Var(p), Path([1]))) + Immutable p3 = Tuple(Field(Var(p2), Path([0])), Lit(Double(3.14)))"#]], + ); +} + +/// Verifies 3-level field path lowering: updating a deeply nested named +/// field within anonymous tuples exercises recursive `lower_update_field` +/// with a 3-element path `[1, 1, 0]`. +#[test] +fn udt_update_field_deeply_nested() { + check_main_local_summaries_after_erasure( + indoc! {" + namespace Test { + newtype Deep = (a: Int, (b: Bool, (c: Double, d: Int))); + @EntryPoint() + function Main() : Unit { + let f = Deep(1, (true, (2.0, 3))); + let f2 = f w/ c <- 3.14; + } + } + "}, + &expect![[r#" + Immutable f = Tuple(Lit(Int(1)), Tuple(Lit(Bool(true)), Tuple(Lit(Double(2.0)), Lit(Int(3))))) + Immutable f2 = Tuple(Field(Var(f), Path([0])), Tuple(Field(Field(Var(f), Path([1])), Path([0])), Tuple(Lit(Double(3.14)), Field(Field(Field(Var(f), Path([1])), Path([1])), Path([1])))))"#]], + ); +} + +/// Verifies `UpdateField` lowering when a UDT contains another UDT: +/// `Outer = (First: Inner, Extra: Bool)` where `Inner = (x: Int, y: Int)`. 
+/// Updating `Extra` (a top-level field) exercises single-level path +/// lowering on a record whose sub-elements are themselves tuples. +#[test] +fn udt_nested_udt_update() { + check_main_local_summaries_after_erasure( + indoc! {" + namespace Test { + newtype Inner = (x: Int, y: Int); + newtype Outer = (First: Inner, Extra: Bool); + @EntryPoint() + function Main() : Unit { + let i = Inner(1, 2); + let o = Outer(i, true); + let o2 = o w/ Extra <- false; + } + } + "}, + &expect![[r#" + Immutable i = Tuple(Lit(Int(1)), Lit(Int(2))) + Immutable o = Tuple(Var(i), Lit(Bool(true))) + Immutable o2 = Tuple(Field(Var(o), Path([0])), Lit(Bool(false)))"#]], + ); +} + +#[test] +fn resolve_ty_udt_with_array_field() { + // UDT with Int[] field: the array element type is unchanged but + // the UDT wrapper is erased. + let item_id = LocalItemId::from(0usize); + let udt_item = make_udt_item( + item_id, + vec![ + ( + Some(Rc::from("vals")), + Ty::Array(Box::new(Ty::Prim(Prim::Int))), + ), + (Some(Rc::from("flag")), Ty::Prim(Prim::Bool)), + ], + ); + let (store, pkg_id) = make_store_with_items(vec![udt_item]); + let cache = build_udt_cache(&store); + + let udt_ty = Ty::Udt(Res::Item(ItemId { + package: pkg_id, + item: item_id, + })); + let resolved = resolve_ty(&cache, &udt_ty); + assert_eq!( + resolved, + Ty::Tuple(vec![ + Ty::Array(Box::new(Ty::Prim(Prim::Int))), + Ty::Prim(Prim::Bool), + ]) + ); +} + +#[test] +fn udt_as_callable_parameter_type() { + // UDT in callable parameter position is erased to tuple. + check_erasure( + indoc! {" + namespace Test { + newtype Pair = (Fst: Int, Snd: Double); + function UsePair(p : Pair) : Int { p::Fst } + @EntryPoint() + function Main() : Unit { + let _ = UsePair(Pair(1, 2.0)); + } + } + "}, + &expect![[r#" + Main: input=Unit, output=Unit + UsePair: input=(Int, Double), output=Int"#]], + ); +} + +#[test] +fn udt_as_callable_return_type() { + // UDT in callable return type is erased to tuple. + check_erasure( + indoc! 
{" + namespace Test { + newtype Pair = (Fst: Int, Snd: Double); + function MakeIt() : Pair { Pair(1, 2.0) } + @EntryPoint() + function Main() : Unit { + let _ = MakeIt(); + } + } + "}, + &expect![[r#" + Main: input=Unit, output=Unit + MakeIt: input=Unit, output=(Int, Double)"#]], + ); +} + +#[test] +fn udt_zero_fields_erased_to_unit() { + // `newtype Marker = Unit` maps to a single-field UDT whose inner type + // is Unit. After erasure the type becomes Unit (scalar). + check_erasure( + indoc! {" + namespace Test { + newtype Marker = Unit; + @EntryPoint() + function Main() : Unit { + let m = Marker(()); + } + } + "}, + &expect![[r#" + Main: input=Unit, output=Unit"#]], + ); +} + +#[test] +fn udt_used_in_nested_callable() { + // UDT created and used inside a helper callable (not Main). + // The erasure should apply to all callables in the package. + check_erasure( + indoc! {" + namespace Test { + newtype Pair = (Fst: Int, Snd: Int); + function MakeAndSum(x : Int) : Int { + let p = Pair(x, x + 1); + p::Fst + p::Snd + } + @EntryPoint() + function Main() : Unit { + let _ = MakeAndSum(5); + } + } + "}, + &expect![[r#" + Main: input=Unit, output=Unit + MakeAndSum: input=Int, output=Int"#]], + ); +} + +#[test] +fn resolve_ty_udt_in_tuple() { + // `(MyPair, Int)` — the inner UDT within a tuple wrapper is resolved. 
+ let item_id = LocalItemId::from(0usize); + let udt_item = make_udt_item( + item_id, + vec![ + (Some(Rc::from("a")), Ty::Prim(Prim::Int)), + (Some(Rc::from("b")), Ty::Prim(Prim::Int)), + ], + ); + let (store, pkg_id) = make_store_with_items(vec![udt_item]); + let cache = build_udt_cache(&store); + + let tuple_ty = Ty::Tuple(vec![ + Ty::Udt(Res::Item(ItemId { + package: pkg_id, + item: item_id, + })), + Ty::Prim(Prim::Bool), + ]); + let resolved = resolve_ty(&cache, &tuple_ty); + assert_eq!( + resolved, + Ty::Tuple(vec![ + Ty::Tuple(vec![Ty::Prim(Prim::Int), Ty::Prim(Prim::Int)]), + Ty::Prim(Prim::Bool), + ]) + ); +} + +#[test] +fn udt_copy_update_expression() { + // `p w/ Fst <- 10` on a two-field UDT should lower to an erased tuple + // that keeps the untouched field as a projection from the source value. + check_main_local_summaries_after_erasure( + indoc! {" + namespace Test { + newtype Pair = (Fst: Int, Snd: Int); + @EntryPoint() + function Main() : Unit { + let p = Pair(1, 2); + let p2 = p w/ Fst <- 10; + } + } + "}, + &expect![[r#" + Immutable p = Tuple(Lit(Int(1)), Lit(Int(2))) + Immutable p2 = Tuple(Lit(Int(10)), Field(Var(p), Path([1])))"#]], + ); +} + +/// Verifies that `new Pair { ...p, Fst = 42 }` on a two-field UDT is +/// lowered to a tuple with the replacement at index 0 after UDT erasure. +#[test] +fn udt_copy_update_single_field() { + check_main_local_summaries_after_erasure( + indoc! {" + namespace Test { + newtype Pair = (Fst: Int, Snd: Double); + @EntryPoint() + function Main() : Unit { + let p = Pair(1, 2.0); + let p2 = new Pair { ...p, Fst = 42 }; + } + } + "}, + &expect![[r#" + Immutable p = Tuple(Lit(Int(1)), Lit(Double(2.0))) + Immutable p2 = Tuple(Lit(Int(42)), Field(Var(p), Path([1])))"#]], + ); +} + +/// Verifies that `new Triple { ...t, A = 1, C = 3 }` on a three-field UDT +/// is lowered to a tuple with replacements at indices 0 and 2. 
+#[test] +fn udt_copy_update_multiple_fields() { + check_main_local_summaries_after_erasure( + indoc! {" + namespace Test { + newtype Triple = (A: Int, B: Double, C: Bool); + @EntryPoint() + function Main() : Unit { + let t = Triple(1, 2.0, false); + let t2 = new Triple { ...t, A = 10, C = true }; + } + } + "}, + &expect![[r#" + Immutable t = Tuple(Lit(Int(1)), Lit(Double(2.0)), Lit(Bool(false))) + Immutable t2 = Tuple(Lit(Int(10)), Field(Var(t), Path([1])), Lit(Bool(true)))"#]], + ); +} + +/// Verifies that copy-update on a single-field UDT is lowered to the scalar +/// replacement value directly. +#[test] +fn udt_copy_update_single_field_udt() { + check_main_local_summaries_after_erasure( + indoc! {" + namespace Test { + newtype Wrapper = (val: Int); + @EntryPoint() + function Main() : Unit { + let w = Wrapper(99); + let w2 = w w/ val <- 10; + } + } + "}, + &expect![[r#" + Immutable w = Lit(Int(99)) + Immutable w2 = Lit(Int(10))"#]], + ); +} + +/// Verifies copy-update on a UDT with nested UDT fields. Updating +/// a top-level field should produce a tuple with the replacement +/// and field extractions for the remaining fields. +#[test] +fn udt_copy_update_nested() { + check_main_local_summaries_after_erasure( + indoc! {" + namespace Test { + newtype Inner = (x: Int, y: Int); + newtype Outer = (First: Inner, Extra: Bool); + @EntryPoint() + function Main() : Unit { + let i = Inner(1, 2); + let o = Outer(i, true); + let o2 = new Outer { ...o, Extra = false }; + } + } + "}, + &expect![[r#" + Immutable i = Tuple(Lit(Int(1)), Lit(Int(2))) + Immutable o = Tuple(Var(i), Lit(Bool(true))) + Immutable o2 = Tuple(Field(Var(o), Path([0])), Lit(Bool(false)))"#]], + ); +} + +#[test] +fn zero_field_udt_erased_to_unit() { + // Zero-field struct: `struct Empty {}` — boundary condition for + // UDT erasure where the underlying type collapses to Unit. + check_erasure( + indoc! 
{" + struct Empty {} + + function Main() : Unit { + let e = new Empty {}; + } + "}, + &expect![[r#" + Main: input=Unit, output=Unit"#]], + ); +} + +#[test] +fn three_level_nested_udt_fully_erased() { + // 3-level nested UDTs: verifies recursive resolution cache handles + // Inner → Middle → Outer chain correctly. + check_erasure( + indoc! {" + struct Inner { X : Int } + struct Middle { I : Inner, Y : Double } + struct Outer { M : Middle, Z : Bool } + + function Main() : Int { + let o = new Outer { M = new Middle { I = new Inner { X = 42 }, Y = 1.0 }, Z = true }; + o.M.I.X + } + "}, + &expect![[r#" + Main: input=Unit, output=Int"#]], + ); +} + +#[test] +fn udt_as_callable_return_type_erased() { + // UDT used as the return type of a callable: the output type + // should be resolved from Ty::Udt to (Int, Double) tuple. + check_erasure( + indoc! {" + struct Pair { Fst : Int, Snd : Double } + + function MakePair(x : Int, y : Double) : Pair { + new Pair { Fst = x, Snd = y } + } + + function Main() : Int { + let p = MakePair(1, 2.0); + p.Fst + } + "}, + &expect![[r#" + Main: input=Unit, output=Int + MakePair: input=(Int, Double), output=(Int, Double)"#]], + ); +} + +#[test] +fn resolve_ty_cache_miss_returns_original_udt() { + // When a Ty::Udt references an item not present in the cache, + // resolve_ty returns the original type unchanged. This is a + // defensive code path — in practice, all UDT items should be + // present in the cache after build_udt_cache. + let item_id = LocalItemId::from(0usize); + let udt_item = make_udt_item( + item_id, + vec![ + (Some(Rc::from("a")), Ty::Prim(Prim::Int)), + (Some(Rc::from("b")), Ty::Prim(Prim::Double)), + ], + ); + let (store, pkg_id) = make_store_with_items(vec![udt_item]); + let cache = build_udt_cache(&store); + + // Reference a different package that has no UDT items in the cache. 
+ let missing_pkg = PackageId::from(99usize); + let missing_ty = Ty::Udt(Res::Item(ItemId { + package: missing_pkg, + item: item_id, + })); + let resolved = resolve_ty(&cache, &missing_ty); + // Cache miss: original type returned unchanged. + assert_eq!(resolved, missing_ty); + + // Also verify a missing item within the same package. + let missing_item = LocalItemId::from(99usize); + let missing_ty2 = Ty::Udt(Res::Item(ItemId { + package: pkg_id, + item: missing_item, + })); + let resolved2 = resolve_ty(&cache, &missing_ty2); + assert_eq!(resolved2, missing_ty2); +} + +#[test] +fn erase_udts_rewrites_reachable_external_package_but_leaves_unreachable_package_untouched() { + let target_pkg_id = PackageId::from(1usize); + let reachable_pkg_id = PackageId::from(2usize); + let unreachable_pkg_id = PackageId::from(3usize); + + let mut store = PackageStore::new(); + let (reachable_udt_item_id, reachable_callable_item_id, reachable_struct_expr_id) = + insert_struct_callable_package(&mut store, reachable_pkg_id, "Reachable", true); + let (_unreachable_udt_item_id, _unreachable_callable_item_id, unreachable_struct_expr_id) = + insert_struct_callable_package(&mut store, unreachable_pkg_id, "Unreachable", false); + + store.insert( + target_pkg_id, + make_entry_package_for_external_callable( + reachable_pkg_id, + reachable_callable_item_id, + reachable_udt_item_id, + ), + ); + + let mut assigner = qsc_fir::assigner::Assigner::from_package(store.get(target_pkg_id)); + erase_udts(&mut store, target_pkg_id, &mut assigner); + crate::invariants::check( + &store, + target_pkg_id, + crate::invariants::InvariantLevel::PostUdtErase, + ); + + let target_package = store.get(target_pkg_id); + let entry_expr = target_package.get_expr(target_package.entry.expect("entry should exist")); + assert_eq!(entry_expr.ty, Ty::Prim(Prim::Bool)); + + let reachable_package = store.get(reachable_pkg_id); + let ItemKind::Callable(reachable_callable) = + 
&reachable_package.get_item(reachable_callable_item_id).kind + else { + panic!("reachable item should be callable"); + }; + assert_eq!(reachable_callable.output, Ty::Prim(Prim::Bool)); + let reachable_struct_expr = reachable_package.get_expr(reachable_struct_expr_id); + assert_eq!(reachable_struct_expr.ty, Ty::Prim(Prim::Bool)); + assert!( + !matches!(reachable_struct_expr.kind, ExprKind::Struct(_, _, _)), + "reachable external package should have struct expressions erased" + ); + + let unreachable_package = store.get(unreachable_pkg_id); + let ItemKind::Callable(unreachable_callable) = + &unreachable_package.get_item(LocalItemId::from(1usize)).kind + else { + panic!("unreachable item should be callable"); + }; + assert!( + matches!(unreachable_callable.output, Ty::Udt(_)), + "unreachable package callable output should remain untouched" + ); + let unreachable_struct_expr = unreachable_package.get_expr(unreachable_struct_expr_id); + assert!( + matches!(unreachable_struct_expr.kind, ExprKind::Struct(_, _, _)), + "unreachable package struct should remain untouched" + ); + assert!( + matches!(unreachable_struct_expr.ty, Ty::Udt(_)), + "unreachable package expression type should remain untouched" + ); +} + +#[test] +#[should_panic(expected = "contains Ty::Udt after UDT erasure")] +fn post_udt_erase_invariants_cover_reachable_external_packages() { + let target_pkg_id = PackageId::from(1usize); + let reachable_pkg_id = PackageId::from(2usize); + + let mut store = PackageStore::new(); + let (reachable_udt_item_id, reachable_callable_item_id, _reachable_struct_expr_id) = + insert_struct_callable_package(&mut store, reachable_pkg_id, "Reachable", true); + + store.insert( + target_pkg_id, + make_entry_package_for_external_callable( + reachable_pkg_id, + reachable_callable_item_id, + reachable_udt_item_id, + ), + ); + + let mut assigner = qsc_fir::assigner::Assigner::from_package(store.get(target_pkg_id)); + erase_udts(&mut store, target_pkg_id, &mut assigner); + + let 
reachable_package = store.get_mut(reachable_pkg_id); + let reachable_item = reachable_package + .items + .get_mut(reachable_callable_item_id) + .expect("reachable callable should exist"); + let ItemKind::Callable(reachable_callable) = &mut reachable_item.kind else { + panic!("reachable item should be callable"); + }; + reachable_callable.output = Ty::Udt(Res::Item(ItemId { + package: reachable_pkg_id, + item: reachable_udt_item_id, + })); + + crate::invariants::check( + &store, + target_pkg_id, + crate::invariants::InvariantLevel::PostUdtErase, + ); +} + +/// Single-field struct declared with struct syntax: `get_pure_ty` returns +/// `Tuple([Int])`, so UDT erase must keep the tuple wrapper rather than +/// unwrapping to scalar. The `PostAll` invariant checks pat/init type +/// alignment and would panic if the expression were incorrectly unwrapped. +#[test] +fn single_field_struct_passes_post_all_invariant() { + use crate::test_utils::{PipelineStage, compile_and_run_pipeline_to}; + + let _ = compile_and_run_pipeline_to( + indoc! {" + struct Single { Value : Int } + + function Main() : Int { + let s = new Single { Value = 42 }; + s.Value + } + "}, + PipelineStage::Full, + ); +} + +/// Single-field struct syntax has a constructor whose pure type is +/// `Tuple([T])`. UDT erase eliminates the constructor call while preserving +/// the tuple wrapper, so the full pipeline passes without type mismatches. +#[test] +fn single_field_struct_constructor_passes_post_all_invariant() { + use crate::test_utils::{PipelineStage, compile_and_run_pipeline_to}; + + let _ = compile_and_run_pipeline_to( + indoc! {" + struct Wrapper { Value : Int } + + function Main() : Int { + let w = new Wrapper { Value = 42 }; + 0 + } + "}, + PipelineStage::Full, + ); +} + +/// Single-field struct syntax produces `UdtDefKind::Tuple([Field])`. Verify +/// UDT erase keeps the tuple wrapper. +#[test] +fn single_field_struct_erased_to_tuple() { + check_main_local_summaries_after_erasure( + indoc! 
{" + namespace Test { + struct Wrapper { Value : Int } + @EntryPoint() + function Main() : Unit { + let w = new Wrapper { Value = 42 }; + } + } + "}, + &expect![[r#" + Immutable w = Tuple(Lit(Int(42)))"#]], + ); +} + +/// Single-field struct variant with a function returning the wrapper type: the +/// erased output type is `(Int,)` (single-element tuple), confirming +/// `UdtDefKind::Tuple([Field])` preserves the tuple wrapper in return position. +#[test] +fn single_field_struct_return_type_erased_to_single_element_tuple() { + check_erasure( + indoc! {" + namespace Test { + struct Wrapper { Value : Int } + function Make() : Wrapper { new Wrapper { Value = 42 } } + @EntryPoint() + function Main() : Unit { + let _ = Make(); + } + } + "}, + &expect![[r#" + Main: input=Unit, output=Unit + Make: input=Unit, output=(Int,)"#]], + ); +} + +/// Control test: non-trailing-comma single-field newtype `(Value : Int)` is +/// erased to scalar `Int` (not a single-element tuple), confirming the +/// `UdtDefKind::Field` → scalar unwrap path. +#[test] +fn non_trailing_comma_newtype_single_field_erased_to_scalar() { + check_erasure( + indoc! {" + namespace Test { + newtype Wrapper = (Value : Int); + function Make() : Wrapper { Wrapper(42) } + @EntryPoint() + function Main() : Unit { + let _ = Make(); + } + } + "}, + &expect![[r#" + Main: input=Unit, output=Unit + Make: input=Unit, output=Int"#]], + ); +} + +#[test] +fn scalar_erased_newtype_field_read_lowered() { + // Field read access on a scalar-erased single-field newtype should be + // lowered. For example: + // - `newtype Wrapper = (x: Int); function Extract(w: Wrapper) : Int { w::x }` + // - After UDT erasure: `w: Prim(Int)` and `w::x` should become just `w` + // - The PostUdtErase invariant requires Field::Path only on Ty::Tuple, + // so this lowering is necessary to satisfy the invariant. + check_callable_body_summary_after_erasure( + indoc! 
{" + namespace Test { + newtype Wrapper = (Value : Int); + function Extract(w : Wrapper) : Int { w::Value } + @EntryPoint() + function Main() : Unit { + let x = Wrapper(42); + let _ = Extract(x); + } + } + "}, + "Extract", + &expect![[r#" + [0] Expr Var(x)"#]], + ); +} + +#[test] +fn udt_erase_is_idempotent() { + let source = indoc! {" + namespace Test { + struct Pair { X : Int, Y : Int } + @EntryPoint() + function Main() : (Int, Int) { + let p = new Pair { X = 1, Y = 2 }; + (p.X, p.Y) + } + } + "}; + let (mut store, pkg_id) = + crate::test_utils::compile_and_run_pipeline_to(source, crate::PipelineStage::UdtErase); + let first = crate::pretty::write_package_qsharp(&store, pkg_id); + let mut assigner = qsc_fir::assigner::Assigner::from_package(store.get(pkg_id)); + erase_udts(&mut store, pkg_id, &mut assigner); + let second = crate::pretty::write_package_qsharp(&store, pkg_id); + assert_eq!(first, second, "udt_erase should be idempotent"); +} + +fn render_before_after_udt_erase(source: &str) -> (String, String) { + let (mut store, pkg_id) = + crate::test_utils::compile_and_run_pipeline_to(source, crate::PipelineStage::Defunc); + let before = crate::pretty::write_package_qsharp(&store, pkg_id); + let mut assigner = qsc_fir::assigner::Assigner::from_package(store.get(pkg_id)); + erase_udts(&mut store, pkg_id, &mut assigner); + let after = crate::pretty::write_package_qsharp(&store, pkg_id); + (before, after) +} + +fn check_before_after_udt_erase(source: &str, expect: &Expect) { + let (before, after) = render_before_after_udt_erase(source); + expect.assert_eq(&format!("BEFORE:\n{before}\nAFTER:\n{after}")); +} + +#[test] +fn before_after_udt_erasure_snapshot() { + check_before_after_udt_erase( + indoc! 
{" + namespace Test { + struct Pair { X : Int, Y : Int } + @EntryPoint() + function Main() : (Int, Int) { + let p = new Pair { X = 1, Y = 2 }; + (p.X, p.Y) + } + } + "}, + &expect![[r#" + BEFORE: + // namespace Test + newtype Pair = (Int, Int); + function Main() : (Int, Int) { + body { + let p : UDT < Item 1(Package 2) > = new Pair { + X = 1, + Y = 2 + }; + (p::X, p::Y) + } + } + // entry + Main() + + AFTER: + // namespace Test + newtype Pair = (Int, Int); + function Main() : (Int, Int) { + body { + let p : (Int, Int) = (1, 2); + (p::Item < 0 >, p::Item < 1 >) + } + } + // entry + Main() + "#]], // snapshot populated by UPDATE_EXPECT=1 + ); +} + +#[test] +fn unreachable_callable_in_reachable_package_is_erased() { + // Verify that Dead callable (not reachable from entry) still gets UDT erasure + // applied because UDT erasure operates at package granularity. + // This locks the package-granular contract against accidental narrowing. + use crate::test_utils::{PipelineStage, compile_and_run_pipeline_to}; + + let (store, pkg_id) = compile_and_run_pipeline_to( + indoc! {" + namespace Test { + @EntryPoint() + operation Main() : Int { + let p = new Pair { X = 1, Y = 2 }; + p.X + p.Y + } + struct Pair { X : Int, Y : Int } + // Dead — never called + operation Dead() : Int { + let p = new Pair { X = 3, Y = 4 }; + p.X + } + } + "}, + PipelineStage::UdtErase, + ); + + // Verify Dead callable still exists and has UDT forms erased. + let package = store.get(pkg_id); + let dead_exists = package.items.values().any( + |item| matches!(&item.kind, ItemKind::Callable(decl) if decl.name.name.as_ref() == "Dead"), + ); + assert!(dead_exists, "Dead should still exist (pre-DCE)"); + + // UDT type items remain in the package after erase_udts — they are only + // removed later by item_dce. Verify that the UDT type item is still present + // (confirming package-granular erasure covers the Dead callable's UDT usage + // without removing the type item itself). 
+ let has_udt = package + .items + .values() + .any(|item| matches!(&item.kind, ItemKind::Ty(..))); + assert!( + has_udt, + "UDT type item should still exist after erase_udts (removed by item_dce later)" + ); +} diff --git a/source/compiler/qsc_fir_transforms/src/walk_utils.rs b/source/compiler/qsc_fir_transforms/src/walk_utils.rs new file mode 100644 index 0000000000..11f7df79bb --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/walk_utils.rs @@ -0,0 +1,378 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Shared expression-tree walkers for FIR transform passes. +//! +//! Provides [`for_each_expr`], a closure-based pre-order walker that +//! eliminates duplicated `ExprKind` matching across transform modules. +//! +//! # Use classification +//! +//! The [`collect_uses_in_block`] and [`collect_uses_in_expr`] helpers +//! classify every occurrence of a [`LocalVarId`] as either a *field-only* +//! use or a *whole-value* use. Tuple-decomposing passes rely on that +//! distinction to decide whether a local can be scalarized safely. +//! +//! - A **"use"** is any expression that mentions the local: a +//! `Var(Res::Local(local))` read, a [`Closure`](ExprKind::Closure) +//! capture, or an assignment whose left-hand side resolves to the local. +//! - **Decomposable assignment.** When the right-hand side of an +//! `Assign(Var(local), Tuple(..))` is a tuple literal, the classifier +//! treats it as a field-only use: each tuple element flows into a +//! separate field so the local's whole value is not reconstituted. +//! - **Closure captures are whole-value.** [`ExprKind::Closure`] captures +//! carry the local by value, so the walkers never attempt to split them +//! even when the captured type is a tuple. +//! - **Non-`Path` `Field` access is whole-value.** A [`Field`] projection +//! that is not a `Field::Path` keeps the record value materialized and is +//! classified as a whole-value use. 
+ +#[cfg(test)] +mod tests; + +use crate::fir_builder::functored_specs; +use qsc_fir::fir::{ + BlockId, CallableImpl, Expr, ExprId, ExprKind, Field, ItemKind, LocalItemId, LocalVarId, + Package, PackageLookup, Res, SpecDecl, SpecImpl, StmtKind, StringComponent, +}; +use rustc_hash::FxHashSet; + +/// Walks an expression tree in pre-order, invoking `visit` for each expression. +/// +/// Does not recurse into closure bodies: [`ExprKind::Closure`] is a leaf from +/// the walker's perspective, so a callable reached only through a closure +/// capture will not appear in the traversal. +pub fn for_each_expr(pkg: &Package, expr_id: ExprId, visit: &mut F) +where + F: FnMut(ExprId, &Expr), +{ + let expr = pkg.get_expr(expr_id); + visit(expr_id, expr); + walk_children(pkg, &expr.kind, visit); +} + +/// Walks all expressions within a block. +/// +/// Does not recurse into closure bodies; see [`for_each_expr`]. +pub fn for_each_expr_in_block(pkg: &Package, block_id: BlockId, visit: &mut F) +where + F: FnMut(ExprId, &Expr), +{ + let block = pkg.get_block(block_id); + for &stmt_id in &block.stmts { + let stmt = pkg.get_stmt(stmt_id); + match &stmt.kind { + StmtKind::Expr(e) | StmtKind::Semi(e) | StmtKind::Local(_, _, e) => { + for_each_expr(pkg, *e, visit); + } + StmtKind::Item(_) => {} + } + } +} + +/// Walks expressions in a callable implementation. +/// +/// Does not recurse into closure bodies; see [`for_each_expr`]. 
+pub fn for_each_expr_in_callable_impl(pkg: &Package, callable_impl: &CallableImpl, visit: &mut F) +where + F: FnMut(ExprId, &Expr), +{ + match callable_impl { + CallableImpl::Intrinsic => {} + CallableImpl::Spec(spec_impl) => { + for_each_expr_in_spec_impl(pkg, spec_impl, visit); + } + CallableImpl::SimulatableIntrinsic(spec_decl) => { + for_each_expr_in_spec_decl(pkg, spec_decl, visit); + } + } +} + +fn for_each_expr_in_spec_impl(pkg: &Package, spec_impl: &SpecImpl, visit: &mut F) +where + F: FnMut(ExprId, &Expr), +{ + for_each_expr_in_spec_decl(pkg, &spec_impl.body, visit); + for spec in functored_specs(spec_impl) { + for_each_expr_in_spec_decl(pkg, spec, visit); + } +} + +fn for_each_expr_in_spec_decl(pkg: &Package, spec_decl: &SpecDecl, visit: &mut F) +where + F: FnMut(ExprId, &Expr), +{ + for_each_expr_in_block(pkg, spec_decl.block, visit); +} + +/// Exhaustive match over all `ExprKind` variants. No wildcard arm — adding a +/// new variant to `ExprKind` will produce a compile error here. +/// +/// Does not recurse into closure bodies: `ExprKind::Closure` is matched as a +/// leaf alongside `Hole`, `Lit`, and `Var`. 
+fn walk_children(pkg: &Package, kind: &ExprKind, visit: &mut F) +where + F: FnMut(ExprId, &Expr), +{ + match kind { + ExprKind::Array(exprs) | ExprKind::ArrayLit(exprs) | ExprKind::Tuple(exprs) => { + for &e in exprs { + for_each_expr(pkg, e, visit); + } + } + ExprKind::ArrayRepeat(a, b) + | ExprKind::Assign(a, b) + | ExprKind::AssignOp(_, a, b) + | ExprKind::BinOp(_, a, b) + | ExprKind::Call(a, b) + | ExprKind::Index(a, b) + | ExprKind::AssignField(a, _, b) + | ExprKind::UpdateField(a, _, b) => { + for_each_expr(pkg, *a, visit); + for_each_expr(pkg, *b, visit); + } + ExprKind::AssignIndex(a, b, c) | ExprKind::UpdateIndex(a, b, c) => { + for_each_expr(pkg, *a, visit); + for_each_expr(pkg, *b, visit); + for_each_expr(pkg, *c, visit); + } + ExprKind::Block(block_id) => { + for_each_expr_in_block(pkg, *block_id, visit); + } + ExprKind::Closure(_, _) | ExprKind::Hole | ExprKind::Lit(_) | ExprKind::Var(_, _) => {} + ExprKind::Fail(e) | ExprKind::Field(e, _) | ExprKind::Return(e) | ExprKind::UnOp(_, e) => { + for_each_expr(pkg, *e, visit); + } + ExprKind::If(cond, body, otherwise) => { + for_each_expr(pkg, *cond, visit); + for_each_expr(pkg, *body, visit); + if let Some(e) = otherwise { + for_each_expr(pkg, *e, visit); + } + } + ExprKind::Range(start, step, end) => { + for e in [start, step, end].into_iter().flatten() { + for_each_expr(pkg, *e, visit); + } + } + ExprKind::Struct(_, copy, fields) => { + if let Some(c) = copy { + for_each_expr(pkg, *c, visit); + } + for fa in fields { + for_each_expr(pkg, fa.value, visit); + } + } + ExprKind::String(components) => { + for component in components { + if let StringComponent::Expr(e) = component { + for_each_expr(pkg, *e, visit); + } + } + } + ExprKind::While(cond, block) => { + for_each_expr(pkg, *cond, visit); + for_each_expr_in_block(pkg, *block, visit); + } + } +} + +/// Classifies uses of `local_id` in a block. +/// +/// Pushes `true` for field-only uses, `false` for whole-value uses. 
+pub(crate) fn collect_uses_in_block( + package: &Package, + block_id: BlockId, + local_id: LocalVarId, + uses: &mut Vec, +) { + let block = package.get_block(block_id); + for &stmt_id in &block.stmts { + let stmt = package.get_stmt(stmt_id); + match &stmt.kind { + StmtKind::Expr(e) | StmtKind::Semi(e) => { + collect_uses_in_expr(package, *e, local_id, uses, false); + } + StmtKind::Local(_, _, expr) => { + collect_uses_in_expr(package, *expr, local_id, uses, false); + } + StmtKind::Item(_) => {} + } + } +} + +/// Recursively classifies uses of `local_id` in an expression. +/// +/// `inside_field` is true when `expr_id` is the direct child of a +/// `Field(_, Path(_))` or non-empty `AssignField(_, Path(_), _)` — meaning the +/// variable reference is being used for field access. +pub(crate) fn collect_uses_in_expr( + package: &Package, + expr_id: ExprId, + local_id: LocalVarId, + uses: &mut Vec, + inside_field: bool, +) { + let expr = package.get_expr(expr_id); + match &expr.kind { + ExprKind::Var(Res::Local(var_id), _) if *var_id == local_id => { + uses.push(inside_field); + } + ExprKind::Field(inner, Field::Path(_)) => { + collect_uses_in_expr(package, *inner, local_id, uses, true); + } + ExprKind::AssignField(record, Field::Path(path), value) if !path.indices.is_empty() => { + collect_uses_in_expr(package, *record, local_id, uses, true); + collect_uses_in_expr(package, *value, local_id, uses, false); + } + ExprKind::Array(es) | ExprKind::ArrayLit(es) | ExprKind::Tuple(es) => { + for &e in es { + collect_uses_in_expr(package, e, local_id, uses, false); + } + } + ExprKind::Assign(a, b) => { + let lhs_expr = package.get_expr(*a); + let rhs_expr = package.get_expr(*b); + if let ExprKind::Var(Res::Local(var_id), _) = &lhs_expr.kind + && *var_id == local_id + && matches!(rhs_expr.kind, ExprKind::Tuple(_)) + { + // Whole-tuple assignment with tuple literal RHS: treat as decomposable. + uses.push(true); + // Walk RHS elements for any uses of local_id. 
+ if let ExprKind::Tuple(elements) = &rhs_expr.kind { + for &e in elements { + collect_uses_in_expr(package, e, local_id, uses, false); + } + } + } else { + collect_uses_in_expr(package, *a, local_id, uses, false); + collect_uses_in_expr(package, *b, local_id, uses, false); + } + } + ExprKind::ArrayRepeat(a, b) + | ExprKind::AssignOp(_, a, b) + | ExprKind::BinOp(_, a, b) + | ExprKind::Call(a, b) + | ExprKind::Index(a, b) + | ExprKind::AssignField(a, _, b) + | ExprKind::UpdateField(a, _, b) => { + collect_uses_in_expr(package, *a, local_id, uses, false); + collect_uses_in_expr(package, *b, local_id, uses, false); + } + ExprKind::AssignIndex(a, b, c) | ExprKind::UpdateIndex(a, b, c) => { + collect_uses_in_expr(package, *a, local_id, uses, false); + collect_uses_in_expr(package, *b, local_id, uses, false); + collect_uses_in_expr(package, *c, local_id, uses, false); + } + ExprKind::Block(block_id) => { + collect_uses_in_block(package, *block_id, local_id, uses); + } + ExprKind::Fail(e) | ExprKind::Return(e) | ExprKind::UnOp(_, e) => { + collect_uses_in_expr(package, *e, local_id, uses, false); + } + ExprKind::Field(inner, _) => { + collect_uses_in_expr(package, *inner, local_id, uses, false); + } + ExprKind::If(cond, body, otherwise) => { + collect_uses_in_expr(package, *cond, local_id, uses, false); + collect_uses_in_expr(package, *body, local_id, uses, false); + if let Some(e) = otherwise { + collect_uses_in_expr(package, *e, local_id, uses, false); + } + } + ExprKind::Range(s, st, e) => { + for x in [s, st, e].into_iter().flatten() { + collect_uses_in_expr(package, *x, local_id, uses, false); + } + } + ExprKind::Struct(_, copy, fields) => { + if let Some(c) = copy { + collect_uses_in_expr(package, *c, local_id, uses, false); + } + for fa in fields { + collect_uses_in_expr(package, fa.value, local_id, uses, false); + } + } + ExprKind::String(components) => { + for c in components { + if let qsc_fir::fir::StringComponent::Expr(e) = c { + collect_uses_in_expr(package, 
*e, local_id, uses, false); + } + } + } + ExprKind::While(cond, block_id) => { + collect_uses_in_expr(package, *cond, local_id, uses, false); + collect_uses_in_block(package, *block_id, local_id, uses); + } + ExprKind::Closure(vars, _) => { + if vars.contains(&local_id) { + uses.push(false); + } + } + ExprKind::Hole | ExprKind::Lit(_) | ExprKind::Var(_, _) => {} + } +} + +/// Collects all expression IDs reachable from the package entry expression. +/// +/// Returns an empty vector when the package has no entry. +pub(crate) fn collect_expr_ids_in_entry(package: &Package) -> Vec { + let mut ids = Vec::new(); + let mut seen = FxHashSet::default(); + if let Some(entry_id) = package.entry { + for_each_expr(package, entry_id, &mut |expr_id, _| { + if seen.insert(expr_id) { + ids.push(expr_id); + } + }); + } + ids +} + +/// Collects all expression IDs from the specialization bodies of the given +/// local callables. +pub(crate) fn collect_expr_ids_in_local_callables( + package: &Package, + local_item_ids: &[LocalItemId], +) -> Vec { + let mut ids = Vec::new(); + let mut seen = FxHashSet::default(); + extend_expr_ids_in_local_callables(package, local_item_ids, &mut ids, &mut seen); + ids +} + +/// Collects all expression IDs from the entry expression and the specialization +/// bodies of the given local callables. +pub(crate) fn collect_expr_ids_in_entry_and_local_callables( + package: &Package, + local_item_ids: &[LocalItemId], +) -> Vec { + let mut ids = collect_expr_ids_in_entry(package); + let mut seen: FxHashSet = ids.iter().copied().collect(); + extend_expr_ids_in_local_callables(package, local_item_ids, &mut ids, &mut seen); + ids +} + +/// Extends an existing expression ID collection with IDs from the given local +/// callable bodies. Skips IDs already in `seen`. 
+pub(crate) fn extend_expr_ids_in_local_callables( + package: &Package, + local_item_ids: &[LocalItemId], + ids: &mut Vec, + seen: &mut FxHashSet, +) { + for &local_item_id in local_item_ids { + let Some(item) = package.items.get(local_item_id) else { + continue; + }; + let ItemKind::Callable(decl) = &item.kind else { + continue; + }; + for_each_expr_in_callable_impl(package, &decl.implementation, &mut |expr_id, _| { + if seen.insert(expr_id) { + ids.push(expr_id); + } + }); + } +} diff --git a/source/compiler/qsc_fir_transforms/src/walk_utils/tests.rs b/source/compiler/qsc_fir_transforms/src/walk_utils/tests.rs new file mode 100644 index 0000000000..2758b059ac --- /dev/null +++ b/source/compiler/qsc_fir_transforms/src/walk_utils/tests.rs @@ -0,0 +1,384 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use super::*; +use crate::test_utils::compile_to_fir; +use expect_test::expect; +use qsc_fir::assigner::Assigner; +use qsc_fir::fir::{CallableImpl, ItemKind, PatKind}; + +/// Finds the body block of the named callable in the user package. +fn find_callable_block(package: &Package, name: &str) -> BlockId { + for item in package.items.values() { + if let ItemKind::Callable(decl) = &item.kind + && decl.name.name.as_ref() == name + && let CallableImpl::Spec(spec) = &decl.implementation + { + return spec.body.block; + } + } + panic!("callable '{name}' not found"); +} + +/// Finds the `LocalVarId` for the first pattern binding with the given name. 
+fn find_local_var(package: &Package, name: &str) -> LocalVarId { + for pat in package.pats.values() { + if let PatKind::Bind(ident) = &pat.kind + && ident.name.as_ref() == name + { + return ident.id; + } + } + panic!("local var '{name}' not found"); +} + +#[test] +fn field_only_access_classified_as_field_use() { + let (store, pkg_id) = compile_to_fir( + "struct Pair { X : Int, Y : Int } + function Main() : Int { + let p = new Pair { X = 1, Y = 2 }; + p.X + p.Y + }", + ); + let package = store.get(pkg_id); + let block_id = find_callable_block(package, "Main"); + let local_id = find_local_var(package, "p"); + let mut uses = Vec::new(); + collect_uses_in_block(package, block_id, local_id, &mut uses); + // Both p.X and p.Y are field-only accesses. + expect![[r#" + [ + true, + true, + ] + "#]] + .assert_debug_eq(&uses); +} + +#[test] +fn whole_value_use_as_function_argument() { + let (store, pkg_id) = compile_to_fir( + "function Consume(t : (Int, Int)) : Int { + let (a, b) = t; + a + b + } + function Main() : Int { + let t = (1, 2); + Consume(t) + }", + ); + let package = store.get(pkg_id); + let block_id = find_callable_block(package, "Main"); + let local_id = find_local_var(package, "t"); + let mut uses = Vec::new(); + collect_uses_in_block(package, block_id, local_id, &mut uses); + // t is passed directly to Consume — whole-value use. + expect![[r#" + [ + false, + ] + "#]] + .assert_debug_eq(&uses); +} + +#[test] +fn decomposable_assign_tuple_literal_rhs() { + let (store, pkg_id) = compile_to_fir( + "function Main() : (Int, Int) { + mutable t = (1, 2); + t = (3, 4); + t + }", + ); + let package = store.get(pkg_id); + let block_id = find_callable_block(package, "Main"); + let local_id = find_local_var(package, "t"); + let mut uses = Vec::new(); + collect_uses_in_block(package, block_id, local_id, &mut uses); + // set t = (3, 4) is decomposable (true), final `t` is whole-value (false). 
+ expect![[r#" + [ + true, + false, + ] + "#]] + .assert_debug_eq(&uses); +} + +#[test] +fn closure_capture_classified_as_whole_use() { + let (store, pkg_id) = compile_to_fir( + "function Apply(f : Int -> Int, x : Int) : Int { f(x) } + function Main() : Int { + let y = 5; + let f = x -> x + y; + Apply(f, 10) + }", + ); + let package = store.get(pkg_id); + let block_id = find_callable_block(package, "Main"); + let local_id = find_local_var(package, "y"); + let mut uses = Vec::new(); + collect_uses_in_block(package, block_id, local_id, &mut uses); + // y is captured by the closure — whole-value use. + expect![[r#" + [ + false, + ] + "#]] + .assert_debug_eq(&uses); +} + +#[test] +fn nested_field_access_classified_as_field_use() { + let (store, pkg_id) = compile_to_fir( + "struct Inner { X : Int } + struct Outer { I : Inner } + function Main() : Int { + let o = new Outer { I = new Inner { X = 42 } }; + o.I.X + }", + ); + let package = store.get(pkg_id); + let block_id = find_callable_block(package, "Main"); + let local_id = find_local_var(package, "o"); + let mut uses = Vec::new(); + collect_uses_in_block(package, block_id, local_id, &mut uses); + // o.I.X is a nested field access — still field-only. 
+ expect![[r#" + [ + true, + ] + "#]] + .assert_debug_eq(&uses); +} + +#[test] +fn walker_visits_nested_expression_kinds_in_program() { + let (store, pkg_id) = compile_to_fir( + "function Main() : Int { + let x = 1 + 2; + let t = (x, 3); + if x > 0 { 10 } else { 20 } + }", + ); + let package = store.get(pkg_id); + let block_id = find_callable_block(package, "Main"); + + let mut kinds: Vec = Vec::new(); + for_each_expr_in_block(package, block_id, &mut |_id, expr| { + let kind_str = match &expr.kind { + ExprKind::Array(_) => "Array", + ExprKind::ArrayLit(_) => "ArrayLit", + ExprKind::ArrayRepeat(_, _) => "ArrayRepeat", + ExprKind::Assign(_, _) => "Assign", + ExprKind::AssignOp(_, _, _) => "AssignOp", + ExprKind::AssignField(_, _, _) => "AssignField", + ExprKind::AssignIndex(_, _, _) => "AssignIndex", + ExprKind::BinOp(_, _, _) => "BinOp", + ExprKind::Block(_) => "Block", + ExprKind::Call(_, _) => "Call", + ExprKind::Closure(_, _) => "Closure", + ExprKind::Fail(_) => "Fail", + ExprKind::Field(_, _) => "Field", + ExprKind::Hole => "Hole", + ExprKind::If(_, _, _) => "If", + ExprKind::Index(_, _) => "Index", + ExprKind::Lit(_) => "Lit", + ExprKind::Range(_, _, _) => "Range", + ExprKind::Return(_) => "Return", + ExprKind::Struct(_, _, _) => "Struct", + ExprKind::String(_) => "String", + ExprKind::UpdateIndex(_, _, _) => "UpdateIndex", + ExprKind::Tuple(_) => "Tuple", + ExprKind::UnOp(_, _) => "UnOp", + ExprKind::UpdateField(_, _, _) => "UpdateField", + ExprKind::Var(_, _) => "Var", + ExprKind::While(_, _) => "While", + }; + kinds.push(kind_str.to_string()); + }); + kinds.sort(); + expect![[r#" + [ + "BinOp", + "BinOp", + "Block", + "Block", + "If", + "Lit", + "Lit", + "Lit", + "Lit", + "Lit", + "Lit", + "Tuple", + "Var", + "Var", + ] + "#]] + .assert_debug_eq(&kinds); +} + +#[test] +fn assigner_ids_do_not_collide_with_existing_package_ids() { + let (store, pkg_id) = compile_to_fir("function Main() : Int { 1 + 2 }"); + let package = store.get(pkg_id); + let mut assigner = 
Assigner::from_package(package); + + // Assigner::from_package tracks expr, stmt, pat, and local IDs. + let new_expr = assigner.next_expr(); + let new_stmt = assigner.next_stmt(); + let new_pat = assigner.next_pat(); + let new_local = assigner.next_local(); + + // Verify allocated IDs are strictly beyond all existing IDs. + let max_expr = package + .exprs + .iter() + .map(|(id, _)| u32::from(id)) + .max() + .unwrap_or(0); + let max_stmt = package + .stmts + .iter() + .map(|(id, _)| u32::from(id)) + .max() + .unwrap_or(0); + let max_pat = package + .pats + .iter() + .map(|(id, _)| u32::from(id)) + .max() + .unwrap_or(0); + + let mut max_local: u32 = 0; + for pat in package.pats.values() { + if let PatKind::Bind(ident) = &pat.kind { + max_local = max_local.max(u32::from(ident.id)); + } + } + + assert!( + u32::from(new_expr) > max_expr, + "new expr {new_expr} should be > max existing {max_expr}" + ); + assert!( + u32::from(new_stmt) > max_stmt, + "new stmt {new_stmt} should be > max existing {max_stmt}" + ); + assert!( + u32::from(new_pat) > max_pat, + "new pat {new_pat} should be > max existing {max_pat}" + ); + assert!( + u32::from(new_local) > max_local, + "new local {new_local} should be > max existing {max_local}" + ); +} + +#[test] +fn collect_entry_expr_ids_returns_all_entry_descendants() { + let (store, pkg_id) = compile_to_fir( + "function Main() : Int { + let x = 1 + 2; + x + }", + ); + let package = store.get(pkg_id); + let ids = collect_expr_ids_in_entry(package); + // The entry expression wraps the call to Main. It should contain at least + // the call expression and the callee/args sub-expressions. + assert!( + !ids.is_empty(), + "entry expression IDs should be non-empty for a program with an entry point" + ); + // All returned IDs should be valid expression IDs in the package. 
+ for &id in &ids { + let _ = package.get_expr(id); + } +} + +#[test] +fn collect_callable_expr_ids_covers_all_specs() { + let (store, pkg_id) = compile_to_fir( + "operation Op() : Unit is Adj + Ctl { + body ... { Message(\"body\"); } + adjoint ... { Message(\"adj\"); } + controlled (cs, ...) { Message(\"ctl\"); } + } + operation Main() : Unit { Op(); }", + ); + let package = store.get(pkg_id); + + // Find Op's LocalItemId. + let op_local_id = package + .items + .iter() + .find_map(|(id, item)| { + if let ItemKind::Callable(decl) = &item.kind { + if decl.name.name.as_ref() == "Op" { + return Some(id); + } + } + None + }) + .expect("Op callable not found"); + + let ids = collect_expr_ids_in_local_callables(package, &[op_local_id]); + // Op has body, adj, and ctl specs — each contains at least a Call expression. + assert!( + ids.len() >= 3, + "expected at least 3 expression IDs covering multiple specs, got {}", + ids.len() + ); + // No duplicates. + let unique: FxHashSet<_> = ids.iter().copied().collect(); + assert_eq!(ids.len(), unique.len(), "expression IDs should be unique"); +} + +#[test] +fn extend_does_not_duplicate_seen_ids() { + let (store, pkg_id) = compile_to_fir( + "function Helper() : Int { 42 } + function Main() : Int { Helper() }", + ); + let package = store.get(pkg_id); + + // Collect all local callable IDs. + let local_ids: Vec<_> = package + .items + .iter() + .filter_map(|(id, item)| { + if let ItemKind::Callable(_) = &item.kind { + Some(id) + } else { + None + } + }) + .collect(); + + // First collection. + let mut ids = Vec::new(); + let mut seen = FxHashSet::default(); + extend_expr_ids_in_local_callables(package, &local_ids, &mut ids, &mut seen); + let first_count = ids.len(); + assert!(first_count > 0, "should collect some expression IDs"); + + // Second extension with same callables — should add nothing. 
+ extend_expr_ids_in_local_callables(package, &local_ids, &mut ids, &mut seen); + assert_eq!( + ids.len(), + first_count, + "second extension should not add duplicates" + ); +} + +#[test] +fn empty_local_items_returns_empty() { + let (store, pkg_id) = compile_to_fir("function Main() : Int { 1 }"); + let package = store.get(pkg_id); + let ids = collect_expr_ids_in_local_callables(package, &[]); + assert!(ids.is_empty(), "empty item list should yield empty result"); +} diff --git a/source/compiler/qsc_fir_transforms/tests/pipeline_contracts.rs b/source/compiler/qsc_fir_transforms/tests/pipeline_contracts.rs new file mode 100644 index 0000000000..1535fd5a18 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/tests/pipeline_contracts.rs @@ -0,0 +1,147 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Contract tests that validate `run_pipeline` output satisfies the `PostAll` +//! invariants expected by downstream consumers (codegen, language service, RCA). +//! +//! Each test compiles a representative Q# program, runs the full FIR transform +//! pipeline, and then calls [`invariants::check`] with [`InvariantLevel::PostAll`] +//! to assert that all structural postconditions hold. +//! +//! These tests are intentionally kept separate from the stage-parity tests in +//! `pipeline_integration.rs` so that contract regressions are easy to triage: +//! a failure here means a downstream consumer may receive malformed FIR. +//! +//! ## Compilation pattern +//! +//! Tests use `compile_to_fir` with `@EntryPoint()` in the source (the same +//! pattern as `pipeline_integration.rs`). This produces a package with a +//! concrete `entry` expression so that `invariants::check` runs the full +//! reachability-based checks rather than returning early. 
+ +use qsc_fir_transforms::{ + invariants, run_pipeline, + test_utils::{assert_no_pipeline_errors, compile_to_fir}, +}; + +// --------------------------------------------------------------------------- +// Helper +// --------------------------------------------------------------------------- + +/// Compiles `source` (which must contain `@EntryPoint()`) through the full FIR +/// transform pipeline and returns the store + package id. +/// +/// Panics if the pipeline reports any errors. +fn compile_and_run_full_pipeline( + source: &str, +) -> (qsc_fir::fir::PackageStore, qsc_fir::fir::PackageId) { + let (mut store, pkg_id) = compile_to_fir(source); + let errors = run_pipeline(&mut store, pkg_id); + assert_no_pipeline_errors("run_pipeline", &errors); + (store, pkg_id) +} + +// --------------------------------------------------------------------------- +// PostAll invariant contract tests +// --------------------------------------------------------------------------- + +/// Core contract test: verifies that `run_pipeline` output on a minimal entry +/// point satisfies the full `PostAll` invariant suite expected by downstream +/// consumers (codegen, language service, RCA). +/// +/// Postconditions asserted by `InvariantLevel::PostAll`: +/// - No `Ty::Param` in reachable code (monomorphization completed). +/// - No `ExprKind::Return` in reachable code (return unification completed). +/// - No `Ty::Arrow` params / `ExprKind::Closure` (defunctionalization completed). +/// - No `Ty::Udt` / `ExprKind::Struct` / `Field::Path` (UDT erasure completed). +/// - All exec-graph ranges populated (exec-graph rebuild completed). +#[test] +fn run_pipeline_output_satisfies_post_all_invariants() { + let (store, pkg_id) = compile_and_run_full_pipeline( + r#" + @EntryPoint() + operation Main() : Int { 42 } + "#, + ); + + // Panics with a descriptive message if any `PostAll` invariant is violated. 
+ invariants::check(&store, pkg_id, invariants::InvariantLevel::PostAll); +} + +/// Verifies that a program using higher-order functions satisfies `PostAll` +/// invariants -- exercises the defunctionalization contract specifically. +#[test] +fn run_pipeline_defunctionalized_output_satisfies_post_all_invariants() { + let (store, pkg_id) = compile_and_run_full_pipeline( + r#" + operation Apply(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + + @EntryPoint() + operation Main() : Unit { + use q = Qubit(); + Apply(H, q); + Reset(q); + } + "#, + ); + + invariants::check(&store, pkg_id, invariants::InvariantLevel::PostAll); +} + +/// Verifies that a program with early returns satisfies `PostAll` invariants -- +/// exercises the return-unification contract specifically. +#[test] +fn run_pipeline_return_unified_output_satisfies_post_all_invariants() { + let (store, pkg_id) = compile_and_run_full_pipeline( + r#" + operation EarlyReturn(flag : Bool) : Int { + if flag { return 1; } + 0 + } + + @EntryPoint() + operation Main() : Int { + EarlyReturn(true) + } + "#, + ); + + invariants::check(&store, pkg_id, invariants::InvariantLevel::PostAll); +} + +/// Verifies that a program using user-defined types satisfies `PostAll` +/// invariants -- exercises the UDT erasure contract specifically. +#[test] +fn run_pipeline_udt_erased_output_satisfies_post_all_invariants() { + let (store, pkg_id) = compile_and_run_full_pipeline( + r#" + newtype Pair = (First : Int, Second : Int); + + @EntryPoint() + operation Main() : Int { + let p = Pair(1, 2); + p::First + } + "#, + ); + + invariants::check(&store, pkg_id, invariants::InvariantLevel::PostAll); +} + +/// Verifies that a program with generic functions satisfies `PostAll` invariants +/// -- exercises the monomorphization contract specifically. 
+#[test] +fn run_pipeline_monomorphized_output_satisfies_post_all_invariants() { + let (store, pkg_id) = compile_and_run_full_pipeline( + r#" + function Identity<'T>(x : 'T) : 'T { x } + + @EntryPoint() + operation Main() : Int { Identity(42) } + "#, + ); + + invariants::check(&store, pkg_id, invariants::InvariantLevel::PostAll); +} diff --git a/source/compiler/qsc_fir_transforms/tests/pipeline_integration.rs b/source/compiler/qsc_fir_transforms/tests/pipeline_integration.rs new file mode 100644 index 0000000000..b3a2de2834 --- /dev/null +++ b/source/compiler/qsc_fir_transforms/tests/pipeline_integration.rs @@ -0,0 +1,1683 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Integration tests that compile Q# source through the full FIR optimization +//! pipeline and cover schedule parity, successful end-to-end validation, and +//! targeted failure regressions. + +use qsc_fir::{ + fir::{CallableImpl, ExprKind, ItemKind, PackageLookup}, + validate::validate, +}; +use qsc_fir_transforms::{ + PipelineError, PipelineStage, invariants, reachability, run_pipeline, run_pipeline_to, + test_utils::{ + assert_callable_body_terminal_expr_matches_block_type, assert_no_pipeline_errors, + compile_to_fir, compile_to_fir_with_entry, expr_kind_short, + }, +}; + +type LoweredOutput = ( + qsc_fir::fir::PackageStore, + qsc_fir::fir::PackageId, + qsc_fir::assigner::Assigner, +); + +/// Compiles a Q# source string as an executable on top of core+std. 
fn compile_and_lower(source: &str) -> LoweredOutput {
    let (store, package_id) = compile_to_fir(source);
    let assigner = qsc_fir::assigner::Assigner::from_package(store.get(package_id));
    (store, package_id, assigner)
}

fn format_pipeline_errors(errors: &[PipelineError]) -> String {
    if errors.is_empty() {
        return "(no error)".to_string();
    }
    errors
        .iter()
        .map(ToString::to_string)
        .collect::<Vec<_>>()
        .join("\n")
}

fn run_pipeline_successfully(
    store: &mut qsc_fir::fir::PackageStore,
    pkg_id: qsc_fir::fir::PackageId,
) {
    let errors = run_pipeline(store, pkg_id);
    assert_no_pipeline_errors("run_pipeline", &errors);
}

fn run_pipeline_to_successfully(
    store: &mut qsc_fir::fir::PackageStore,
    pkg_id: qsc_fir::fir::PackageId,
    stage: PipelineStage,
) {
    let errors = run_pipeline_to(store, pkg_id, stage, &[]);
    assert_no_pipeline_errors("run_pipeline_to", &errors);
}

/// Returns the body specialization of a callable, panicking for intrinsics.
fn callable_body_spec<'a>(
    decl: &'a qsc_fir::fir::CallableDecl,
    callable_name: &str,
) -> &'a qsc_fir::fir::SpecDecl {
    match &decl.implementation {
        CallableImpl::Spec(spec_impl) => &spec_impl.body,
        CallableImpl::SimulatableIntrinsic(spec) => spec,
        CallableImpl::Intrinsic => panic!("callable '{callable_name}' should have a body"),
    }
}

/// Sorted names of all callables in `pkg_id` reachable from the entry point.
fn reachable_callable_names(
    store: &qsc_fir::fir::PackageStore,
    pkg_id: qsc_fir::fir::PackageId,
) -> Vec<String> {
    let package = store.get(pkg_id);
    let reachable = reachability::collect_reachable_from_entry(store, pkg_id);

    let mut names: Vec<String> = reachable
        .iter()
        .filter(|store_id| store_id.package == pkg_id)
        .filter_map(|store_id| match &package.get_item(store_id.item).kind {
            ItemKind::Callable(decl) => Some(decl.name.name.to_string()),
            _ => None,
        })
        .collect();
    names.sort();
    names
}

/// Sorted one-line signatures (input/output types) of reachable callables.
fn reachable_callable_summary(
    store: &qsc_fir::fir::PackageStore,
    pkg_id: qsc_fir::fir::PackageId,
) -> String {
    let package = store.get(pkg_id);
    let reachable = reachability::collect_reachable_from_entry(store, pkg_id);

    let mut lines: Vec<String> = reachable
        .iter()
        .filter(|store_id| store_id.package == pkg_id)
        .filter_map(|store_id| {
            let ItemKind::Callable(decl) = &package.get_item(store_id.item).kind else {
                return None;
            };
            let pat = package.get_pat(decl.input);
            Some(format!(
                "{}: input_ty={}, output_ty={}",
                decl.name.name, pat.ty, decl.output
            ))
        })
        .collect();
    lines.sort();
    lines.join("\n")
}

/// Renders the body block of `callable_name` as one line per statement, with
/// statement kinds and types, for stage-parity snapshot comparisons.
fn callable_body_summary(
    store: &qsc_fir::fir::PackageStore,
    pkg_id: qsc_fir::fir::PackageId,
    callable_name: &str,
) -> String {
    let package = store.get(pkg_id);
    let item = package
        .items
        .values()
        .find(|item| match &item.kind {
            ItemKind::Callable(decl) => decl.name.name.as_ref() == callable_name,
            _ => false,
        })
        .expect("callable should exist");

    let ItemKind::Callable(decl) = &item.kind else {
        panic!("item should be callable");
    };
    let spec = callable_body_spec(decl, callable_name);
    let block = package.get_block(spec.block);

    let mut lines = vec![format!("block_ty={}", block.ty)];
    for (index, stmt_id) in block.stmts.iter().enumerate() {
        let line = match &package.get_stmt(*stmt_id).kind {
            qsc_fir::fir::StmtKind::Expr(expr_id) => {
                let expr = package.get_expr(*expr_id);
                format!(
                    "[{index}] Expr ty={} {}",
                    expr.ty,
                    expr_kind_short(package, *expr_id)
                )
            }
            qsc_fir::fir::StmtKind::Semi(expr_id) => {
                let expr = package.get_expr(*expr_id);
                format!(
                    "[{index}] Semi ty={} {}",
                    expr.ty,
                    expr_kind_short(package, *expr_id)
                )
            }
            qsc_fir::fir::StmtKind::Local(_, pat_id, expr_id) => {
                let pat = package.get_pat(*pat_id);
                let expr = package.get_expr(*expr_id);
                format!(
                    "[{index}] Local pat_ty={} init_ty={} {}",
                    pat.ty,
                    expr.ty,
                    expr_kind_short(package, *expr_id)
                )
            }
            qsc_fir::fir::StmtKind::Item(local_item_id) => {
                format!("[{index}] Item {local_item_id}")
            }
        };
        lines.push(line);
    }

    lines.join("\n")
}

fn package_has_callable_named(
    store: &qsc_fir::fir::PackageStore,
    pkg_id: qsc_fir::fir::PackageId,
    callable_name: &str,
) -> bool {
    store.get(pkg_id).items.values().any(|item| match &item.kind {
        ItemKind::Callable(decl) => decl.name.name.as_ref() == callable_name,
        _ => false,
    })
}

/// True when `expr_id` is a reference to the named callable, possibly wrapped
/// in unary functor applications (`Adjoint`/`Controlled`).
fn expr_targets_callable(
    package: &qsc_fir::fir::Package,
    pkg_id: qsc_fir::fir::PackageId,
    expr_id: qsc_fir::fir::ExprId,
    callable_name: &str,
) -> bool {
    match &package.get_expr(expr_id).kind {
        ExprKind::Var(qsc_fir::fir::Res::Item(item_id), _)
            if item_id.package == pkg_id
                && matches!(
                    &package.get_item(item_id.item).kind,
                    ItemKind::Callable(decl) if decl.name.name.as_ref() == callable_name
                ) =>
        {
            true
        }
        // Look through unary operators (functor applications) to the callee.
        ExprKind::UnOp(_, inner_id) => {
            expr_targets_callable(package, pkg_id, *inner_id, callable_name)
        }
        _ => false,
    }
}

#[test]
fn post_arg_promote_cut_matches_full_pipeline_bodies() {
    let source = r#"
        operation Identity<'T>(x : 'T) : 'T { x }
        operation Apply(op : Qubit => Unit, q : Qubit) : Unit {
            op(q);
        }
        @EntryPoint()
        operation Main() : Int {
            use q = Qubit();
            let angle = 1.0;
            Apply(q1 => Rx(angle, q1), q);
            let pair = Identity((M(q), 7));
            Reset(q);
            let (_, value) = pair;
            value
        }
    "#;

    let (mut post_arg_store, post_arg_pkg_id, _) = compile_and_lower(source);
    let (mut full_store, full_pkg_id, _) = compile_and_lower(source);

    run_pipeline_to_successfully(
        &mut post_arg_store,
        post_arg_pkg_id,
        PipelineStage::ArgPromote,
    );
    run_pipeline_successfully(&mut full_store, full_pkg_id);

    invariants::check(
        &post_arg_store,
        post_arg_pkg_id,
        invariants::InvariantLevel::PostArgPromote,
    );

    validate(full_store.get(full_pkg_id), &full_store);

    assert_eq!(
        reachable_callable_summary(&post_arg_store, post_arg_pkg_id),
        reachable_callable_summary(&full_store, full_pkg_id)
    );

    let post_arg_callables = reachable_callable_names(&post_arg_store, post_arg_pkg_id);
    let full_callables = reachable_callable_names(&full_store, full_pkg_id);
    assert_eq!(post_arg_callables, full_callables);

    for callable_name in &full_callables {
        assert_eq!(
            callable_body_summary(&post_arg_store, post_arg_pkg_id, callable_name),
            callable_body_summary(&full_store, full_pkg_id, callable_name),
            "callable '{callable_name}' body drift between PostArgPromote and full pipeline"
        );
    }
}

#[test]
fn terminal_result_block_shape_stays_valid_across_stage_boundaries() {
    let source = r#"
        namespace Test {
            @EntryPoint()
            operation Main() : Result {
                use q = Qubit();
                let r = M(q);
                Reset(q);
                return r;
            }
        }
    "#;

    let (store, pkg_id, _) = compile_and_lower(source);
    let (mut post_return_store, post_return_pkg_id, _) = compile_and_lower(source);
    let (mut post_all_store, post_all_pkg_id, _) = compile_and_lower(source);

    let mut snapshots = vec![format!(
        "Lowered\n{}",
        callable_body_summary(&store, pkg_id, "Main")
    )];

    run_pipeline_to_successfully(
        &mut post_return_store,
        post_return_pkg_id,
        PipelineStage::ReturnUnify,
    );
    snapshots.push(format!(
        "PostReturnUnify\n{}",
        callable_body_summary(&post_return_store, post_return_pkg_id, "Main")
    ));
    assert_callable_body_terminal_expr_matches_block_type(
        &post_return_store,
        post_return_pkg_id,
        "Main",
    );

    run_pipeline_to_successfully(&mut post_all_store, post_all_pkg_id, PipelineStage::Full);
    snapshots.push(format!(
        "PostAll\n{}",
        callable_body_summary(&post_all_store, post_all_pkg_id, "Main")
    ));
    assert_callable_body_terminal_expr_matches_block_type(&post_all_store, post_all_pkg_id, "Main");

    let expected = concat!(
        "Lowered\n",
        "block_ty=Result\n",
        "[0] Local pat_ty=Qubit init_ty=Qubit Call\n",
        "[1] Local pat_ty=Result init_ty=Result Call\n",
        "[2] Semi ty=Unit Call\n",
        "[3] Semi ty=Unit Block\n",
        "[4] Semi ty=Unit Call\n",
        "\n",
        "PostReturnUnify\n",
        "block_ty=Result\n",
        "[0] Local pat_ty=Qubit init_ty=Qubit Call\n",
        "[1] Local pat_ty=Result init_ty=Result Call\n",
        "[2] Semi ty=Unit Call\n",
        "[3] Expr ty=Result Block\n\n",
        "PostAll\n",
        "block_ty=Result\n",
        "[0] Local pat_ty=Qubit init_ty=Qubit Call\n",
        "[1] Local pat_ty=Result init_ty=Result Call\n",
        "[2] Semi ty=Unit Call\n",
        "[3] Expr ty=Result Block"
    );
    assert_eq!(snapshots.join("\n\n"), expected);
}

#[test]
fn terminal_result_array_block_shape_through_use_scope_stays_valid() {
    let source = r#"
        namespace Test {
            @EntryPoint()
            operation SearchForMarkedInput() : Result[] {
                let nQubits = 2;
                use qubits = Qubit[nQubits] {
                    return MResetEachZ(qubits);
                }
            }
        }
    "#;
    let (mut post_return_store, post_return_pkg_id, _) = compile_and_lower(source);
    let (mut post_all_store, post_all_pkg_id, _) = compile_and_lower(source);

    run_pipeline_to_successfully(
        &mut post_return_store,
        post_return_pkg_id,
        PipelineStage::ReturnUnify,
    );
    assert_callable_body_terminal_expr_matches_block_type(
        &post_return_store,
        post_return_pkg_id,
        "SearchForMarkedInput",
    );

    run_pipeline_to_successfully(&mut post_all_store, post_all_pkg_id, PipelineStage::Full);
    assert_callable_body_terminal_expr_matches_block_type(
        &post_all_store,
        post_all_pkg_id,
        "SearchForMarkedInput",
    );
}

#[test]
fn simple_entry_point_passes_all_invariants() {
    let (mut fir_store, fir_pkg_id, _) = compile_and_lower("operation Main() : Int { 42 }");
    run_pipeline_successfully(&mut fir_store, fir_pkg_id);
    validate(fir_store.get(fir_pkg_id), &fir_store);
}

#[test]
fn generic_identity_monomorphized_to_concrete_type() {
    let (mut fir_store, fir_pkg_id, _) = compile_and_lower(
        r#"
        operation Identity<'T>(x : 'T) : 'T { x }
        operation Main() : Int { Identity(42) }
        "#,
    );
    run_pipeline_successfully(&mut fir_store, fir_pkg_id);
    let package = fir_store.get(fir_pkg_id);
    validate(package, &fir_store);
invariants::check(&fir_store, fir_pkg_id, invariants::InvariantLevel::PostAll); +} + +#[test] +fn qubit_allocation_preserved_through_pipeline() { + let (mut fir_store, fir_pkg_id, _) = compile_and_lower( + r#" + operation Main() : Result { + use q = Qubit(); + H(q); + let r = M(q); + Reset(q); + r + } + "#, + ); + run_pipeline_successfully(&mut fir_store, fir_pkg_id); + let package = fir_store.get(fir_pkg_id); + validate(package, &fir_store); +} + +#[test] +fn callable_argument_defunctionalized_to_direct_call() { + let (mut fir_store, fir_pkg_id, _) = compile_and_lower( + r#" + operation Apply(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Main() : Unit { + use q = Qubit(); + Apply(H, q); + Reset(q); + } + "#, + ); + run_pipeline_successfully(&mut fir_store, fir_pkg_id); + let package = fir_store.get(fir_pkg_id); + validate(package, &fir_store); +} + +#[test] +fn tuple_return_scalars_promoted_by_sroa() { + let (mut fir_store, fir_pkg_id, _) = compile_and_lower( + r#" + operation Pair() : (Int, Bool) { (1, true) } + operation Main() : Int { + let (a, _) = Pair(); + a + } + "#, + ); + run_pipeline_successfully(&mut fir_store, fir_pkg_id); + let package = fir_store.get(fir_pkg_id); + validate(package, &fir_store); +} + +#[test] +fn for_loop_iterators_pass_invariants() { + let (mut fir_store, fir_pkg_id, _) = compile_and_lower( + r#" + operation Main() : Int { + mutable sum = 0; + for i in 0..4 { + sum += i; + } + sum + } + "#, + ); + run_pipeline_successfully(&mut fir_store, fir_pkg_id); + let package = fir_store.get(fir_pkg_id); + validate(package, &fir_store); +} + +#[test] +fn array_operations_pass_post_pipeline_invariants() { + let (mut fir_store, fir_pkg_id, _) = compile_and_lower( + r#" + operation Main() : Int { + let arr = [1, 2, 3]; + arr[1] + } + "#, + ); + run_pipeline_successfully(&mut fir_store, fir_pkg_id); + let package = fir_store.get(fir_pkg_id); + validate(package, &fir_store); +} + +#[test] +fn 
composite_while_return_survives_full_pipeline() { + let (mut fir_store, fir_pkg_id, _) = compile_and_lower( + r#" + namespace Test { + struct Pair { + Left : Int, + Right : Bool + } + + function Helper() : Pair { + mutable i = 0; + while i < 3 { + if i == 1 { + return new Pair { Left = i, Right = true }; + } + i += 1; + } + new Pair { Left = -1, Right = false } + } + + @EntryPoint() + operation Main() : Int { + let _ = Helper(); + 0 + } + } + "#, + ); + + let errors = run_pipeline(&mut fir_store, fir_pkg_id); + assert_no_pipeline_errors("run_pipeline", &errors); + + let package = fir_store.get(fir_pkg_id); + validate(package, &fir_store); + invariants::check(&fir_store, fir_pkg_id, invariants::InvariantLevel::PostAll); +} + +#[test] +fn run_pipeline_returns_dynamic_callable_defunctionalization_diagnostics() { + let (mut fir_store, fir_pkg_id, _) = compile_and_lower( + r#" + operation ApplyOp(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + + @EntryPoint() + operation Main() : Unit { + use q = Qubit(); + mutable op = H; + for _ in 0..3 { + op = X; + } + ApplyOp(op, q); + } + "#, + ); + + let errors = run_pipeline(&mut fir_store, fir_pkg_id); + + assert_eq!( + errors.len(), + 1, + "expected one defunctionalization diagnostic, got:\n{}", + format_pipeline_errors(&errors) + ); + assert!( + matches!( + errors.as_slice(), + [PipelineError::Defunctionalize( + qsc_fir_transforms::defunctionalize::Error::DynamicCallable(_) + )] + ), + "expected a DynamicCallable diagnostic, got:\n{}", + format_pipeline_errors(&errors) + ); + assert_eq!( + errors[0].to_string(), + "callable argument could not be resolved statically" + ); +} + +#[test] +fn apply_operation_power_a_library_repro_trips_local_var_consistency() { + let (mut fir_store, fir_pkg_id, _) = compile_and_lower( + r#" + @EntryPoint() + operation Main() : Result { + use q = Qubit(); + ApplyOperationPowerA(12, Rx(Std.Math.PI()/16.0, _), q); + ApplyOperationPowerA(-3, Rx(Std.Math.PI()/4.0, _), q); + M(q) + } + "#, + ); 
+ + run_pipeline_successfully(&mut fir_store, fir_pkg_id); + + let package = fir_store.get(fir_pkg_id); + validate(package, &fir_store); + invariants::check(&fir_store, fir_pkg_id, invariants::InvariantLevel::PostAll); +} + +#[test] +fn apply_operation_power_ca_library_repro_preserves_local_var_consistency() { + let (mut fir_store, fir_pkg_id, _) = compile_and_lower( + r#" + operation Consume(apply_power_of_u : (Int, Qubit[]) => Unit is Adj + Ctl, target : Qubit[]) : Result { + apply_power_of_u(1, target); + M(target[0]) + } + + operation U(qs : Qubit[]) : Unit is Adj + Ctl { + H(qs[0]); + } + + @EntryPoint() + operation Main() : Result { + use qs = Qubit[1]; + Consume(ApplyOperationPowerCA(_, U, _), qs) + } + "#, + ); + + run_pipeline_successfully(&mut fir_store, fir_pkg_id); + + let package = fir_store.get(fir_pkg_id); + validate(package, &fir_store); + invariants::check(&fir_store, fir_pkg_id, invariants::InvariantLevel::PostAll); +} + +#[test] +fn apply_operation_power_ca_array_lambda_preserves_call_shape() { + let (mut fir_store, fir_pkg_id, _) = compile_and_lower( + r#" + operation Main() : Unit { + use state = Qubit(); + use phase = Qubit[2]; + let oracle = ApplyOperationPowerCA(_, qs => U(qs[0]), _); + ApplyQPE(oracle, [state], phase); + } + + operation U(q : Qubit) : Unit is Ctl + Adj { + Rz(Std.Math.PI() / 3.0, q); + } + "#, + ); + + run_pipeline_successfully(&mut fir_store, fir_pkg_id); + + let package = fir_store.get(fir_pkg_id); + validate(package, &fir_store); + invariants::check(&fir_store, fir_pkg_id, invariants::InvariantLevel::PostAll); +} + +#[test] +fn pipeline_preserves_entry_expression() { + let (mut fir_store, fir_pkg_id, _) = compile_and_lower("operation Main() : Int { 99 }"); + run_pipeline_successfully(&mut fir_store, fir_pkg_id); + let package = fir_store.get(fir_pkg_id); + assert!( + package.entry.is_some(), + "entry expression must still exist after pipeline" + ); +} + +#[test] +fn nested_generics_fully_monomorphized() { + let (mut 
fir_store, fir_pkg_id, _) = compile_and_lower( + r#" + function Inner<'T>(x : 'T) : 'T { x } + function Outer<'T>(x : 'T) : 'T { Inner(x) } + @EntryPoint() + operation Main() : Unit { let _ = Outer(42); } + "#, + ); + run_pipeline_successfully(&mut fir_store, fir_pkg_id); + let package = fir_store.get(fir_pkg_id); + validate(package, &fir_store); + invariants::check(&fir_store, fir_pkg_id, invariants::InvariantLevel::PostAll); +} + +#[test] +fn generic_for_loop_monomorphized_and_invariants_hold() { + let (mut fir_store, fir_pkg_id, _) = compile_and_lower( + r#" + operation Apply<'T>(op : ('T => Unit), items : 'T[]) : Unit { + for item in items { op(item); } + } + @EntryPoint() + operation Main() : Unit { + use qs = Qubit[3]; + Apply(H, qs); + } + "#, + ); + run_pipeline_successfully(&mut fir_store, fir_pkg_id); + let package = fir_store.get(fir_pkg_id); + validate(package, &fir_store); + invariants::check(&fir_store, fir_pkg_id, invariants::InvariantLevel::PostAll); +} + +#[test] +fn cross_package_apply_to_each_inlined_and_valid() { + let (mut fir_store, fir_pkg_id, _) = compile_and_lower( + r#" + open Std.Canon; + @EntryPoint() + operation Main() : Unit { + use qs = Qubit[3]; + ApplyToEach(H, qs); + } + "#, + ); + run_pipeline_successfully(&mut fir_store, fir_pkg_id); + let package = fir_store.get(fir_pkg_id); + validate(package, &fir_store); + invariants::check(&fir_store, fir_pkg_id, invariants::InvariantLevel::PostAll); +} + +#[test] +fn multiple_generic_instantiations_each_specialized() { + let (mut fir_store, fir_pkg_id, _) = compile_and_lower( + r#" + function Identity<'T>(x : 'T) : 'T { x } + @EntryPoint() + operation Main() : Unit { + let a = Identity(42); + let b = Identity(1.0); + } + "#, + ); + run_pipeline_successfully(&mut fir_store, fir_pkg_id); + let package = fir_store.get(fir_pkg_id); + validate(package, &fir_store); + invariants::check(&fir_store, fir_pkg_id, invariants::InvariantLevel::PostAll); +} + +#[test] +fn 
cross_package_nested_generics_fully_resolved() { + // Uses Std.Arrays.Mapped (generic) which internally calls other std + // generic helpers. This exercises the cross-package nested-generic + // worklist: cloning Mapped into user package discovers further + // cross-package generic references that must also be specialized. + let (mut fir_store, fir_pkg_id, _) = compile_and_lower( + r#" + open Std.Arrays; + function PlusOne(x : Int) : Int { x + 1 } + @EntryPoint() + operation Main() : Unit { + let arr = [1, 2, 3]; + let mapped = Mapped(PlusOne, arr); + } + "#, + ); + run_pipeline_successfully(&mut fir_store, fir_pkg_id); + let package = fir_store.get(fir_pkg_id); + validate(package, &fir_store); + invariants::check(&fir_store, fir_pkg_id, invariants::InvariantLevel::PostAll); +} + +#[test] +fn closure_specialization_preserves_lambda_tuple_call_shape() { + let (mut fir_store, fir_pkg_id, _) = compile_and_lower( + r#" + @EntryPoint() + operation Main() : (Int, Bool)[] { + Microsoft.Quantum.Arrays.Enumerated([true, false]) + } + "#, + ); + + run_pipeline_to_successfully(&mut fir_store, fir_pkg_id, PipelineStage::Full); + + let package = fir_store.get(fir_pkg_id); + let mapper = package + .items + .values() + .find_map(|item| match &item.kind { + ItemKind::Callable(decl) + if decl + .name + .name + .as_ref() + .starts_with("MappedByIndex") => + { + Some(decl.as_ref()) + } + _ => None, + }) + .unwrap_or_else(|| { + panic!( + "MappedByIndex specialization should exist\n{}", + reachable_callable_summary(&fir_store, fir_pkg_id) + ) + }); + + let lambda_names = package + .items + .values() + .filter_map(|item| match &item.kind { + ItemKind::Callable(decl) if decl.name.name.as_ref().starts_with("<lambda>") => { + Some(decl.name.name.to_string()) + } + _ => None, + }) + .collect::<Vec<_>>(); + let args_id = package + .exprs + .values() + .find_map(|expr| match &expr.kind { + ExprKind::Call(callee_id, args_id) + if expr_targets_callable(package, fir_pkg_id, *callee_id, "<lambda>") => + { + 
Some(*args_id) + } + _ => None, + }) + .unwrap_or_else(|| { + panic!( + "specialized mapper body should call the lifted lambda directly\nmapper body:\n{}\nlambdas:\n{}", + callable_body_summary( + &fir_store, + fir_pkg_id, + mapper.name.name.as_ref(), + ), + lambda_names.join("\n") + ) + }); + + let args_expr = package.get_expr(args_id); + assert_eq!( + args_expr.ty.to_string(), + "((Int, Bool),)", + "direct lambda calls should preserve closure-style argument packaging" + ); + + let ExprKind::Tuple(args_items) = &args_expr.kind else { + panic!("direct lambda call should package its argument as a one-element tuple"); + }; + assert_eq!( + args_items.len(), + 1, + "lambda call should have exactly one packaged argument" + ); + + let inner_expr = package.get_expr(args_items[0]); + assert_eq!(inner_expr.ty.to_string(), "(Int, Bool)"); + assert!( + matches!(&inner_expr.kind, ExprKind::Tuple(items) if items.len() == 2), + "inner packaged lambda argument should remain the original pair" + ); + + validate(package, &fir_store); + invariants::check(&fir_store, fir_pkg_id, invariants::InvariantLevel::PostAll); +} + +#[test] +fn direct_lambda_calls_preserve_nested_tuple_packaging() { + let (mut fir_store, fir_pkg_id, _) = compile_and_lower( + r#" + @EntryPoint() + operation Main() : Int { + let add = (x, y) -> x + y; + add(2, 3) + } + "#, + ); + + run_pipeline_to_successfully(&mut fir_store, fir_pkg_id, PipelineStage::Full); + + let package = fir_store.get(fir_pkg_id); + let lambda_names = package + .items + .values() + .filter_map(|item| match &item.kind { + ItemKind::Callable(decl) if decl.name.name.as_ref().starts_with("<lambda>") => { + Some(decl.name.name.to_string()) + } + _ => None, + }) + .collect::<Vec<_>>(); + let args_id = package + .exprs + .values() + .find_map(|expr| match &expr.kind { + ExprKind::Call(callee_id, args_id) + if expr_targets_callable(package, fir_pkg_id, *callee_id, "<lambda>") => + { + Some(*args_id) + } + _ => None, + }) + .unwrap_or_else(|| { + panic!( + "Main should 
call the lifted lambda directly\nMain body:\n{}\nlambdas:\n{}", + callable_body_summary(&fir_store, fir_pkg_id, "Main"), + lambda_names.join("\n") + ) + }); + + let args_expr = package.get_expr(args_id); + assert_eq!( + args_expr.ty.to_string(), + "((Int, Int),)", + "direct lambda calls should preserve the original tuple argument as one packaged value" + ); + + let ExprKind::Tuple(args_items) = &args_expr.kind else { + panic!("direct lambda call should package its argument as a one-element tuple"); + }; + assert_eq!( + args_items.len(), + 1, + "lambda call should have exactly one packaged argument" + ); + + let inner_expr = package.get_expr(args_items[0]); + assert_eq!(inner_expr.ty.to_string(), "(Int, Int)"); + assert!( + matches!(&inner_expr.kind, ExprKind::Tuple(items) if items.len() == 2), + "inner packaged lambda argument should remain the original pair" + ); + + validate(package, &fir_store); + invariants::check(&fir_store, fir_pkg_id, invariants::InvariantLevel::PostAll); +} + +#[test] +fn entry_expression_simple_call_passes_pipeline() { + let source = r#" + namespace Test { + operation Greet() : Result { + use q = Qubit(); + H(q); + M(q) + } + } + "#; + let (mut store, pkg_id) = compile_to_fir_with_entry(source, "Test.Greet()"); + run_pipeline_successfully(&mut store, pkg_id); + let package = store.get(pkg_id); + validate(package, &store); +} + +#[test] +fn entry_expression_with_callable_arg_passes_pipeline() { + let source = r#" + namespace Test { + operation Apply(op : Qubit => Unit, q : Qubit) : Unit { + op(q); + } + operation Run() : Result { + use q = Qubit(); + Apply(H, q); + M(q) + } + } + "#; + let (mut store, pkg_id) = compile_to_fir_with_entry(source, "Test.Run()"); + run_pipeline_successfully(&mut store, pkg_id); + let package = store.get(pkg_id); + validate(package, &store); +} + +#[test] +fn multi_arrow_multi_level_hof_passes_pipeline() { + let source = r#" + namespace Test { + operation ApplyBoth(f : Qubit => Unit, g : Qubit => Unit, q : 
Qubit) : Unit { + f(q); + g(q); + } + + operation Compose( + inner : ((Qubit => Unit, Qubit => Unit, Qubit) => Unit), + f : Qubit => Unit, + g : Qubit => Unit, + q : Qubit + ) : Unit { + inner(f, g, q); + } + + @EntryPoint() + operation Main() : Result { + use q = Qubit(); + Compose(ApplyBoth, H, X, q); + M(q) + } + } + "#; + let (mut store, pkg_id) = compile_to_fir(source); + run_pipeline_successfully(&mut store, pkg_id); + let package = store.get(pkg_id); + validate(package, &store); + invariants::check(&store, pkg_id, invariants::InvariantLevel::PostAll); +} + +/// Exercises a UDT that wraps a callable type through the full pipeline. +/// +/// The callable-wrapping UDT is constructed and unwrapped locally without +/// appearing in callable signatures, so defunctionalization can eliminate +/// the inner arrow type without requiring parameter-level changes. This +/// confirms that `defunctionalize` safely handles `Ty::Udt` nodes that +/// contain callable fields. +#[test] +fn udt_wrapping_callable_survives_full_pipeline() { + let source = r#" + namespace Test { + newtype MyOp = (Qubit => Unit); + + @EntryPoint() + operation Main() : Unit { + let wrapped = MyOp(q => H(q)); + use q = Qubit(); + (wrapped!)(q); + Reset(q); + } + } + "#; + let (mut store, pkg_id) = compile_to_fir(source); + run_pipeline_successfully(&mut store, pkg_id); + let package = store.get(pkg_id); + validate(package, &store); +} + +/// Exercises a cross-package UDT constructor through the full pipeline. +/// Uses the `Complex` struct from the core library, which is exported and +/// available to user code. +/// +/// NOTE: The Q# frontend resolver fails to resolve cross-package UDT +/// constructors in expression position, producing `Res::Err` / `Ty::Err` +/// before any pipeline transforms run. See `qsc_frontend/src/lower.rs` +/// line 1059 for the `hir::Res::Err` fallback. This is a frontend bug, +/// not a pipeline bug. 
+#[test] +fn cross_package_udt_constructor_resolution() { + let source = r#" + @EntryPoint() + operation Main() : Int { + let c = Complex(1.0, 2.0); + 0 + } + "#; + let (mut store, pkg_id) = compile_to_fir(source); + run_pipeline_successfully(&mut store, pkg_id); + let package = store.get(pkg_id); + validate(package, &store); +} + +/// Local multi-field UDT with a callable field that is never invoked. +/// UDT erasure exposes the arrow type inside the tuple; the invariant +/// must tolerate this between UDT erasure and SROA. +#[test] +fn local_multi_field_udt_callable_never_invoked() { + let source = r#" + namespace Test { + newtype Config = (Count: Int, Op: Qubit[] => Unit is Adj); + operation NoOp(qs : Qubit[]) : Unit is Adj {} + @EntryPoint() + operation Main() : Int { let cfg = Config(0, NoOp); 0 } + } + "#; + let (mut store, pkg_id) = compile_to_fir(source); + run_pipeline_successfully(&mut store, pkg_id); + let package = store.get(pkg_id); + validate(package, &store); +} + +/// Local multi-field UDT with a callable field extracted via field accessor +/// and invoked. Confirms that defunc and UDT erasure cooperate correctly +/// when the callable is actually called. +#[test] +fn local_multi_field_udt_callable_field_invoked() { + let source = r#" + namespace Test { + newtype Config = (Count: Int, Op: Qubit[] => Unit is Adj); + operation NoOp(qs : Qubit[]) : Unit is Adj {} + @EntryPoint() + operation Main() : Unit { + let cfg = Config(0, NoOp); + use qs = Qubit[cfg::Count]; + cfg::Op(qs); + } + } + "#; + let (mut store, pkg_id) = compile_to_fir(source); + run_pipeline_successfully(&mut store, pkg_id); + let package = store.get(pkg_id); + validate(package, &store); +} + +/// Local multi-field UDT with a callable field passed to a higher-order +/// function. Exercises defunc's expression-level analysis when the arrow +/// value flows through a HOF call site. 
+#[test] +#[ignore = "defunc limitation: callable extracted via UDT field accessor (w::F) cannot be statically resolved when passed to a HOF"] +fn local_multi_field_udt_callable_passed_to_hof() { + let source = r#" + namespace Test { + newtype Wrapper = (Count: Int, F: Int -> Int); + function Inc(x: Int) : Int { x + 1 } + function Apply(f: Int -> Int, x: Int) : Int { f(x) } + @EntryPoint() + operation Main() : Int { + let w = Wrapper(0, Inc); + Apply(w::F, 5) + } + } + "#; + let (mut store, pkg_id) = compile_to_fir(source); + run_pipeline_successfully(&mut store, pkg_id); + let package = store.get(pkg_id); + validate(package, &store); +} + +// ============================================================================ +// Stage-Parity Integration Tests +// ============================================================================ +// These tests verify that FIR output at each pipeline stage is parity with +// the full pipeline. Stage-parity ensures that: +// +// 1. Callable count remains consistent (callables are not unexpectedly added/removed) +// 2. Statement IDs are valid references (no dangling refs to removed items) +// 3. Executable graph is well-formed or empty as expected +// 4. Type correctness is preserved across the stage boundary +// 5. Package structure and export lists remain consistent + +#[test] +fn stage_parity_mono_monomorphization_preserves_callable_types() { + // Stage-parity check after monomorphization. + // + // Invariant: After Mono, all generic parameters are erased and concrete + // monomorphized callables exist. Callable count should match full pipeline + // (Mono doesn't create or remove callables; it specializes them). + // + // Importance: Mono is the first transformation. Validating its output + // parity ensures subsequent passes inherit a well-formed FIR with no + // unexpected callable additions or deletions. 
+ let source = r#" + function Identity<'T>(x : 'T) : 'T { x } + @EntryPoint() + operation Main() : Int { + let a = Identity(42); + let b = Identity(1.5); + a + } + "#; + + let (mut post_mono_store, post_mono_pkg_id, _) = compile_and_lower(source); + let (mut full_store, full_pkg_id, _) = compile_and_lower(source); + + run_pipeline_to_successfully(&mut post_mono_store, post_mono_pkg_id, PipelineStage::Mono); + run_pipeline_successfully(&mut full_store, full_pkg_id); + + invariants::check( + &post_mono_store, + post_mono_pkg_id, + invariants::InvariantLevel::PostMono, + ); + + let full_package = full_store.get(full_pkg_id); + validate(full_package, &full_store); + + // Callable count parity: Mono should not add/remove callables. + let post_mono_callables = reachable_callable_names(&post_mono_store, post_mono_pkg_id); + let full_callables = reachable_callable_names(&full_store, full_pkg_id); + assert_eq!( + post_mono_callables, full_callables, + "callable set must be identical after Mono and full pipeline" + ); + + // Type consistency: callable signatures should be identical. + assert_eq!( + reachable_callable_summary(&post_mono_store, post_mono_pkg_id), + reachable_callable_summary(&full_store, full_pkg_id) + ); +} + +#[test] +fn stage_parity_defunc_defunctionalization_eliminates_callable_types() { + // Stage-parity check after defunctionalization. + // + // Invariant: After Defunc, all arrow types and closure expressions + // have been eliminated from reachable code. Callable-wrapping closures + // are lifted to callable declarations, but the count in reachable code + // should match the full pipeline (lifted callables participate in + // reachability from Main). + // + // Importance: Defunc is a high-value transformation that changes the + // structure of the FIR significantly. Validating parity ensures that + // callable creation during lifting does not introduce duplicate or + // stray callables, and that the reachable set is stable. 
+ let source = r#" + operation Apply(op : Qubit => Unit, q : Qubit) : Unit { op(q); } + @EntryPoint() + operation Main() : Unit { + use q = Qubit(); + Apply(H, q); + Reset(q); + } + "#; + + let (mut post_defunc_store, post_defunc_pkg_id, _) = compile_and_lower(source); + let (mut full_store, full_pkg_id, _) = compile_and_lower(source); + + run_pipeline_to_successfully( + &mut post_defunc_store, + post_defunc_pkg_id, + PipelineStage::Defunc, + ); + run_pipeline_successfully(&mut full_store, full_pkg_id); + + invariants::check( + &post_defunc_store, + post_defunc_pkg_id, + invariants::InvariantLevel::PostDefunc, + ); + + let full_package = full_store.get(full_pkg_id); + validate(full_package, &full_store); + + // Callable count parity: lifted callables should be in reachable set. + let post_defunc_callables = reachable_callable_names(&post_defunc_store, post_defunc_pkg_id); + let full_callables = reachable_callable_names(&full_store, full_pkg_id); + assert_eq!( + post_defunc_callables, full_callables, + "callable set after Defunc must match full pipeline" + ); + + // Type summary parity. + assert_eq!( + reachable_callable_summary(&post_defunc_store, post_defunc_pkg_id), + reachable_callable_summary(&full_store, full_pkg_id) + ); +} + +#[test] +fn stage_parity_udt_erase_eliminates_udt_types() { + // Stage-parity check after UDT erasure. + // + // Invariant: After UdtErase, all Ty::Udt types are erased from + // reachable code. UDT-wrapping callables are eliminated only if + // they become unreachable (deferred to item_dce). Reachable callables + // should match the full pipeline output. + // + // Importance: UDT erasure is a significant structural transformation + // that rewrites type signatures. Validating parity ensures no callables + // are unexpectedly preserved or removed at this stage. 
+ let source = r#" + namespace Test { + newtype Wrapper = (x: Int); + function Extract(w: Wrapper) : Int { w::x } + @EntryPoint() + operation Main() : Int { + let w = Wrapper(42); + Extract(w) + } + } + "#; + + let (mut post_udt_store, post_udt_pkg_id, _) = compile_and_lower(source); + let (mut full_store, full_pkg_id, _) = compile_and_lower(source); + + run_pipeline_to_successfully( + &mut post_udt_store, + post_udt_pkg_id, + PipelineStage::UdtErase, + ); + run_pipeline_successfully(&mut full_store, full_pkg_id); + + invariants::check( + &post_udt_store, + post_udt_pkg_id, + invariants::InvariantLevel::PostUdtErase, + ); + + let full_package = full_store.get(full_pkg_id); + validate(full_package, &full_store); + + // Callable count parity. + let post_udt_callables = reachable_callable_names(&post_udt_store, post_udt_pkg_id); + let full_callables = reachable_callable_names(&full_store, full_pkg_id); + assert_eq!( + post_udt_callables, full_callables, + "callable set after UdtErase must match full pipeline" + ); + + // Type summary parity. + assert_eq!( + reachable_callable_summary(&post_udt_store, post_udt_pkg_id), + reachable_callable_summary(&full_store, full_pkg_id) + ); +} + +#[test] +fn stage_parity_tuple_comp_lower_lowers_tuple_equality() { + // Stage-parity check after tuple comparison lowering. + // + // Invariant: After TupleCompLower, all tuple equality and inequality + // operations are lowered to scalar comparisons and logical operators. + // No BinOp(Eq/Neq) with tuple operands should exist. Callable count + // should match full pipeline. + // + // Importance: TupleCompLower is a mid-pipeline pass that preserves the + // callable set while rewriting expression structure. Validating parity + // ensures no unexpected side effects on the callable structure. 
+ let source = r#" + @EntryPoint() + operation Main() : Bool { + let pair1 = (1, 2); + let pair2 = (1, 2); + pair1 == pair2 + } + "#; + + let (mut post_tuple_store, post_tuple_pkg_id, _) = compile_and_lower(source); + let (mut full_store, full_pkg_id, _) = compile_and_lower(source); + + run_pipeline_to_successfully( + &mut post_tuple_store, + post_tuple_pkg_id, + PipelineStage::TupleCompLower, + ); + run_pipeline_successfully(&mut full_store, full_pkg_id); + + invariants::check( + &post_tuple_store, + post_tuple_pkg_id, + invariants::InvariantLevel::PostTupleCompLower, + ); + + let full_package = full_store.get(full_pkg_id); + validate(full_package, &full_store); + + // Callable count parity. + let post_tuple_callables = reachable_callable_names(&post_tuple_store, post_tuple_pkg_id); + let full_callables = reachable_callable_names(&full_store, full_pkg_id); + assert_eq!( + post_tuple_callables, full_callables, + "callable set after TupleCompLower must match full pipeline" + ); + + // Type summary parity. + assert_eq!( + reachable_callable_summary(&post_tuple_store, post_tuple_pkg_id), + reachable_callable_summary(&full_store, full_pkg_id) + ); +} + +#[test] +fn stage_parity_sroa_body_shape_matches_full_pipeline() { + // Stage-parity check after Scalar Replacement of Aggregates (SROA). + // + // Invariant: After SROA, the reachable callable set, signature surface, + // and callable body shape for this source program should already match + // the later full-pipeline result. + // + // Importance: SROA is a data-flow optimization that rewrites local + // patterns and parameter decomposition. Validating parity ensures that + // the scalarization does not introduce unexpected new callables or + // remove callables unexpectedly. 
+ let source = r#" + function Pair() : (Int, Bool) { (1, true) } + @EntryPoint() + operation Main() : Int { + let (a, _) = Pair(); + a + } + "#; + + let (mut post_sroa_store, post_sroa_pkg_id, _) = compile_and_lower(source); + let (mut full_store, full_pkg_id, _) = compile_and_lower(source); + + run_pipeline_to_successfully(&mut post_sroa_store, post_sroa_pkg_id, PipelineStage::Sroa); + run_pipeline_successfully(&mut full_store, full_pkg_id); + + invariants::check( + &post_sroa_store, + post_sroa_pkg_id, + invariants::InvariantLevel::PostSroa, + ); + + let full_package = full_store.get(full_pkg_id); + validate(full_package, &full_store); + + // Callable count parity. + let post_sroa_callables = reachable_callable_names(&post_sroa_store, post_sroa_pkg_id); + let full_callables = reachable_callable_names(&full_store, full_pkg_id); + assert_eq!( + post_sroa_callables, full_callables, + "callable set after SROA must match full pipeline" + ); + + // Type summary parity. + assert_eq!( + reachable_callable_summary(&post_sroa_store, post_sroa_pkg_id), + reachable_callable_summary(&full_store, full_pkg_id) + ); + + // Body summary parity: callable bodies should be identical after SROA + // and full pipeline. + for callable_name in &full_callables { + assert_eq!( + callable_body_summary(&post_sroa_store, post_sroa_pkg_id, callable_name), + callable_body_summary(&full_store, full_pkg_id, callable_name), + "callable '{callable_name}' body must match after SROA and full pipeline" + ); + } +} + +#[test] +fn stage_parity_item_dce_reachable_surface_matches_full_pipeline() { + // Stage-parity check after item-level dead code elimination. + // + // Invariant: After ItemDce, the reachable callable surface should match + // the full pipeline output for this program. A separate regression test + // below asserts direct removal of an unreachable callable item. + // + // Importance: ItemDce is a critical pass that removes dead code. 
This + // test validates that DCE correctly identifies and preserves reachable + // items while eliminating only truly dead items, avoiding premature + // removal or over-retention of items. + let source = r#" + function Unused() : Int { 99 } + function Used() : Int { 42 } + @EntryPoint() + operation Main() : Int { Used() } + "#; + + let (mut post_dce_store, post_dce_pkg_id, _) = compile_and_lower(source); + let (mut full_store, full_pkg_id, _) = compile_and_lower(source); + + run_pipeline_to_successfully(&mut post_dce_store, post_dce_pkg_id, PipelineStage::ItemDce); + run_pipeline_successfully(&mut full_store, full_pkg_id); + + invariants::check( + &post_dce_store, + post_dce_pkg_id, + invariants::InvariantLevel::PostArgPromote, // ItemDce runs after ArgPromote + ); + + let full_package = full_store.get(full_pkg_id); + validate(full_package, &full_store); + + // Callable count parity: reachable callables must match. + let post_dce_callables = reachable_callable_names(&post_dce_store, post_dce_pkg_id); + let full_callables = reachable_callable_names(&full_store, full_pkg_id); + assert_eq!( + post_dce_callables, full_callables, + "reachable callable set after ItemDce must match full pipeline" + ); + + // Type summary parity. + assert_eq!( + reachable_callable_summary(&post_dce_store, post_dce_pkg_id), + reachable_callable_summary(&full_store, full_pkg_id) + ); +} + +#[test] +fn stage_parity_exec_graph_rebuild_reconstructs_execution_graph() { + // Stage-parity check after execution graph rebuild. + // + // Invariant: After ExecGraphRebuild, the execution graph is reconstructed + // from the rewritten FIR. All EMPTY_EXEC_RANGE sentinels from earlier + // passes are replaced with valid execution graph ranges. Callable bodies + // should match the full pipeline output, and the package structure + // should be stable. + // + // Importance: ExecGraphRebuild is the final structural pass. 
This test + // validates that the execution graph reconstruction does not alter + // callable definitions, introduce new callables, or remove existing + // ones. The reconstructed graph should be well-formed and match the + // full pipeline's graph. + let source = r#" + operation Identity<'T>(x : 'T) : 'T { x } + operation Apply(op : Qubit => Unit, q : Qubit) : Unit { op(q); } + @EntryPoint() + operation Main() : Unit { + use q = Qubit(); + Apply(H, q); + let _ = Identity(42); + Reset(q); + } + "#; + + let (mut post_rebuild_store, post_rebuild_pkg_id, _) = compile_and_lower(source); + let (mut full_store, full_pkg_id, _) = compile_and_lower(source); + + run_pipeline_to_successfully( + &mut post_rebuild_store, + post_rebuild_pkg_id, + PipelineStage::ExecGraphRebuild, + ); + run_pipeline_successfully(&mut full_store, full_pkg_id); + + invariants::check( + &post_rebuild_store, + post_rebuild_pkg_id, + invariants::InvariantLevel::PostAll, // ExecGraphRebuild is the last pass + ); + + let full_package = full_store.get(full_pkg_id); + validate(full_package, &full_store); + + // Callable count parity. + let post_rebuild_callables = reachable_callable_names(&post_rebuild_store, post_rebuild_pkg_id); + let full_callables = reachable_callable_names(&full_store, full_pkg_id); + assert_eq!( + post_rebuild_callables, full_callables, + "callable set after ExecGraphRebuild must match full pipeline" + ); + + // Type summary parity. + assert_eq!( + reachable_callable_summary(&post_rebuild_store, post_rebuild_pkg_id), + reachable_callable_summary(&full_store, full_pkg_id) + ); + + // Body summary parity: all callable bodies must match. 
+ for callable_name in &full_callables { + assert_eq!( + callable_body_summary(&post_rebuild_store, post_rebuild_pkg_id, callable_name), + callable_body_summary(&full_store, full_pkg_id, callable_name), + "callable '{callable_name}' body must match after ExecGraphRebuild and full pipeline" + ); + } +} + +#[test] +fn stage_parity_mono_type_stability() { + // Regression test for generic specialization at PostMono. + // + // Invariant: Generic specialization at Mono produces monomorphized callables. + // The callable count (reachable from entry) should match the full pipeline, + // indicating that Mono neither creates nor removes callables unexpectedly. + let source = r#" + operation Generic<'T>(x: 'T) : Unit { } + @EntryPoint() + operation Main() : Unit { + Generic(1); + Generic("str"); + } + "#; + + let (mut post_mono_store, post_mono_pkg_id, _) = compile_and_lower(source); + let (mut full_store, full_pkg_id, _) = compile_and_lower(source); + + run_pipeline_to_successfully(&mut post_mono_store, post_mono_pkg_id, PipelineStage::Mono); + run_pipeline_successfully(&mut full_store, full_pkg_id); + + invariants::check( + &post_mono_store, + post_mono_pkg_id, + invariants::InvariantLevel::PostMono, + ); + + let mono_callable_count = reachable_callable_names(&post_mono_store, post_mono_pkg_id).len(); + let full_callable_count = reachable_callable_names(&full_store, full_pkg_id).len(); + + assert_eq!( + mono_callable_count, full_callable_count, + "generic specialization should not change reachable callable count at PostMono" + ); +} + +#[test] +fn stage_parity_defunc_hof_elimination() { + // Regression test for HOF callable elimination at PostDefunc. + // + // Invariant: After defunctionalization, HOF callables (with arrow types) + // have been eliminated and replaced by lifted specializations. The reachable + // callable count should reflect this transformation and match the full pipeline. 
+ let source = r#" + operation Apply(op : Qubit => Unit, q : Qubit) : Unit { op(q); } + @EntryPoint() + operation Main() : Unit { + use q = Qubit(); + Apply(H, q); + Apply(X, q); + } + "#; + + let (mut post_defunc_store, post_defunc_pkg_id, _) = compile_and_lower(source); + let (mut full_store, full_pkg_id, _) = compile_and_lower(source); + + run_pipeline_to_successfully( + &mut post_defunc_store, + post_defunc_pkg_id, + PipelineStage::Defunc, + ); + run_pipeline_successfully(&mut full_store, full_pkg_id); + + invariants::check( + &post_defunc_store, + post_defunc_pkg_id, + invariants::InvariantLevel::PostDefunc, + ); + + let defunc_callables = reachable_callable_names(&post_defunc_store, post_defunc_pkg_id); + let full_callables = reachable_callable_names(&full_store, full_pkg_id); + + assert_eq!( + defunc_callables, full_callables, + "HOF elimination should produce consistent callable set at PostDefunc" + ); +} + +#[test] +fn stage_parity_tuple_comp_lower_no_residual() { + // Regression test for tuple comparison lowering at PostTupleCompLower. + // + // Invariant: After tuple comparison lowering, no binary equality operations + // on tuple-typed operands remain in reachable code. This test verifies the + // lowering completes without introducing residual BinOp(Eq, Tuple) expressions. 
+ let source = r#" + @EntryPoint() + operation Main() : Bool { + let pair = (1, 2); + let other = (1, 2); + pair == other + } + "#; + + let (mut post_tuple_store, post_tuple_pkg_id, _) = compile_and_lower(source); + let (mut full_store, full_pkg_id, _) = compile_and_lower(source); + + run_pipeline_to_successfully( + &mut post_tuple_store, + post_tuple_pkg_id, + PipelineStage::TupleCompLower, + ); + run_pipeline_successfully(&mut full_store, full_pkg_id); + + invariants::check( + &post_tuple_store, + post_tuple_pkg_id, + invariants::InvariantLevel::PostTupleCompLower, + ); + + let post_tuple_callables = reachable_callable_names(&post_tuple_store, post_tuple_pkg_id); + let full_callables = reachable_callable_names(&full_store, full_pkg_id); + + assert_eq!( + post_tuple_callables, full_callables, + "tuple comparison lowering should preserve callable structure" + ); +} + +#[test] +fn stage_parity_item_dce_removes_unreachable_callable_items() { + // Regression test for item DCE removing dead callable items. + // + // Invariant: After item DCE, callable items that are not reachable from + // the entry expression are removed from the package item table. 
+ let source = r#" + operation Unused() : Unit { } + operation Used() : Unit { } + @EntryPoint() + operation Main() : Unit { Used(); } + "#; + + let (mut pre_dce_store, pre_dce_pkg_id, _) = compile_and_lower(source); + let (mut post_dce_store, post_dce_pkg_id, _) = compile_and_lower(source); + + run_pipeline_to_successfully(&mut pre_dce_store, pre_dce_pkg_id, PipelineStage::Gc); + run_pipeline_to_successfully(&mut post_dce_store, post_dce_pkg_id, PipelineStage::ItemDce); + + let pre_dce_callables = reachable_callable_names(&pre_dce_store, pre_dce_pkg_id); + let post_dce_callables = reachable_callable_names(&post_dce_store, post_dce_pkg_id); + + assert!( + package_has_callable_named(&pre_dce_store, pre_dce_pkg_id, "Unused"), + "pre-ItemDce package should still contain dead callable item 'Unused'" + ); + assert!( + post_dce_callables.len() <= pre_dce_callables.len(), + "item DCE should not increase reachable callable count" + ); + assert!( + !package_has_callable_named(&post_dce_store, post_dce_pkg_id, "Unused"), + "ItemDce should remove unreachable callable item 'Unused'" + ); + assert!( + package_has_callable_named(&post_dce_store, post_dce_pkg_id, "Used"), + "ItemDce should keep reachable callable item 'Used'" + ); + assert!( + package_has_callable_named(&post_dce_store, post_dce_pkg_id, "Main"), + "ItemDce should keep the entry callable item 'Main'" + ); + + invariants::check( + &post_dce_store, + post_dce_pkg_id, + invariants::InvariantLevel::PostGc, + ); +} diff --git a/source/compiler/qsc_frontend/src/closure.rs b/source/compiler/qsc_frontend/src/closure.rs index cf7c5ac605..19726d9d7f 100644 --- a/source/compiler/qsc_frontend/src/closure.rs +++ b/source/compiler/qsc_frontend/src/closure.rs @@ -318,6 +318,26 @@ pub(super) fn partial_app_tuple( (expr, PartialApp { bindings, input }) } +/// Creates the input pattern for a lifted closure callable. 
+/// +/// For non-zero captures, the result is `PatKind::Tuple(captures ++ [input])` with +/// `Ty::Tuple(capture_tys ++ [input_ty])`, which is the standard closure calling convention: +/// fixed captures are prepended to the user's input. +/// +/// For zero captures, the result is still `PatKind::Tuple([input])` with `Ty::Tuple([input_ty])`. +/// This 1-tuple wrapping is an intentional convention — **not** incidental — and multiple +/// downstream passes depend on it: +/// +/// - `direct_lambda_packaged_input` (defunc rewrite) detects zero-capture lambdas by matching +/// `Ty::Tuple(items) if items.len() == 1` +/// - `rewrite_direct_closure_args` wraps call-site arguments in `Tuple([args])` to match +/// - `map_input_pattern_to_input_expressions` (RCA) uses `skip_ahead` logic assuming the 1-tuple +/// - `merge_fixed_args` (eval) wraps `Value::Tuple([arg])` for `Some([])` +/// - `resolve_args` (partial eval) has a fallback for post-defunc mismatches +/// +/// Changing this to return bare `input` for zero captures requires coordinated updates +/// across all five sites: `direct_lambda_packaged_input`, `rewrite_direct_closure_args`, +/// `map_input_pattern_to_input_expressions`, `merge_fixed_args`, and `resolve_args`. 
fn closure_input( vars: impl IntoIterator, input: Pat, diff --git a/source/compiler/qsc_frontend/src/lower/tests.rs b/source/compiler/qsc_frontend/src/lower/tests.rs index fae64744e2..9ce35e0960 100644 --- a/source/compiler/qsc_frontend/src/lower/tests.rs +++ b/source/compiler/qsc_frontend/src/lower/tests.rs @@ -7,6 +7,10 @@ use indoc::indoc; use qsc_data_structures::{ language_features::LanguageFeatures, source::SourceMap, target::TargetCapabilityFlags, }; +use qsc_hir::{ + hir::{ItemKind, SpecBody, StmtKind}, + ty::{Prim, Ty}, +}; fn check_hir(input: &str, expect: &Expect) { let sources = SourceMap::new([("test".into(), input.into())], None); @@ -20,6 +24,17 @@ fn check_hir(input: &str, expect: &Expect) { expect.assert_eq(&unit.package.to_string()); } +fn compile_unit(input: &str) -> compile::CompileUnit { + let sources = SourceMap::new([("test".into(), input.into())], None); + compile( + &PackageStore::new(compile::core()), + &[], + sources, + TargetCapabilityFlags::all(), + LanguageFeatures::default(), + ) +} + fn check_errors(input: &str, expect: &Expect) { let sources = SourceMap::new([("test".into(), input.into())], None); let unit = compile( @@ -258,6 +273,74 @@ fn lift_local_operation() { ); } +#[test] +fn explicit_qubit_annotation_preserves_type_through_resolution_and_lowering() { + let unit = compile_unit(indoc! 
{" + namespace input { + operation Foo() : Unit { + use q : Qubit = Qubit(); + let x = 3; + } + } + "}); + assert!(unit.errors.is_empty(), "{:?}", unit.errors); + + let namespace = unit + .ast + .package + .nodes + .iter() + .find_map(|node| match node { + qsc_ast::ast::TopLevelNode::Namespace(namespace) => Some(namespace), + qsc_ast::ast::TopLevelNode::Stmt(_) => None, + }) + .expect("namespace should exist"); + let ast_callable = namespace + .items + .iter() + .find_map(|item| match &*item.kind { + qsc_ast::ast::ItemKind::Callable(callable) if callable.name.name.as_ref() == "Foo" => { + Some(callable) + } + _ => None, + }) + .expect("Foo AST callable should exist"); + let qsc_ast::ast::CallableBody::Block(ast_block) = &*ast_callable.body else { + panic!("Foo AST callable should have a block body"); + }; + let qsc_ast::ast::StmtKind::Qubit(_, ast_pat, _, None) = &*ast_block.stmts[0].kind else { + panic!("first AST statement should be the qubit allocation"); + }; + let qsc_ast::ast::PatKind::Bind(_, Some(_)) = &*ast_pat.kind else { + panic!("AST qubit pattern should retain the explicit annotation"); + }; + + assert_eq!( + unit.ast.tys.terms.get(ast_pat.id), + Some(&Ty::Prim(Prim::Qubit)), + "type table should preserve the resolved explicit qubit annotation" + ); + + let callable = unit + .package + .items + .values() + .find_map(|item| match &item.kind { + ItemKind::Callable(callable) if callable.name.name.as_ref() == "Foo" => Some(callable), + _ => None, + }) + .expect("Foo callable should exist"); + let SpecBody::Impl(_, block) = &callable.body.body else { + panic!("Foo should have an implementation body"); + }; + let StmtKind::Qubit(_, pat, init, None) = &block.stmts[0].kind else { + panic!("first statement should be the raw qubit allocation"); + }; + + assert_eq!(pat.ty, Ty::Prim(Prim::Qubit)); + assert_eq!(init.ty, Ty::Prim(Prim::Qubit)); +} + #[test] fn lift_local_newtype() { check_hir( diff --git a/source/compiler/qsc_frontend/src/resolve.rs 
b/source/compiler/qsc_frontend/src/resolve.rs index a4d3be96ce..40fba4102d 100644 --- a/source/compiler/qsc_frontend/src/resolve.rs +++ b/source/compiler/qsc_frontend/src/resolve.rs @@ -1125,6 +1125,7 @@ impl AstVisitor<'_> for With<'_> { self.resolver.bind_pat(pat, stmt.span.hi); } ast::StmtKind::Qubit(_, pat, init, block) => { + self.visit_pat(pat); ast_visit::walk_qubit_init(self, init); if let Some(block) = block { self.with_pat(block.span, ScopeKind::Block, pat, |visitor| { diff --git a/source/compiler/qsc_lowerer/src/lib.rs b/source/compiler/qsc_lowerer/src/lib.rs index 10b035d4d1..0acb354a20 100644 --- a/source/compiler/qsc_lowerer/src/lib.rs +++ b/source/compiler/qsc_lowerer/src/lib.rs @@ -52,7 +52,7 @@ pub struct ExecGraphBuilder { impl ExecGraphBuilder { /// Takes the built execution graph and resets the builder. - fn take(&mut self) -> ExecGraph { + pub fn take(&mut self) -> ExecGraph { let debug_exec_graph = self .debug .drain(..) @@ -71,18 +71,18 @@ impl ExecGraphBuilder { } /// Pushes a node to *only* the debug execution graph. - fn debug_push(&mut self, node: ExecGraphDebugNode) { + pub fn debug_push(&mut self, node: ExecGraphDebugNode) { self.debug.push(ExecGraphNode::Debug(node)); } /// Pushes a node to the execution graph. - fn push(&mut self, node: ExecGraphNode) { + pub fn push(&mut self, node: ExecGraphNode) { self.no_debug.push(node); self.debug.push(node); } /// Constructs a node with the given argument, then pushes it to the execution graph. - fn push_with_arg(&mut self, node_fn: F, arg: ExecGraphIdx) + pub fn push_with_arg(&mut self, node_fn: F, arg: ExecGraphIdx) where F: Fn(u32) -> ExecGraphNode, { @@ -98,14 +98,15 @@ impl ExecGraphBuilder { } /// Pushes a return node to the execution graph. - fn push_ret(&mut self) { + pub fn push_ret(&mut self) { self.no_debug.push(ExecGraphNode::Ret); self.debug .push(ExecGraphNode::Debug(ExecGraphDebugNode::RetFrame)); } /// Returns the current length of the execution graph. 
- fn len(&self) -> ExecGraphIdx { + #[must_use] + pub fn len(&self) -> ExecGraphIdx { ExecGraphIdx { no_debug_idx: self.no_debug.len(), debug_idx: self.debug.len(), @@ -113,7 +114,7 @@ impl ExecGraphBuilder { } /// Constructs a node with the given argument, then sets it at the given index in the execution graph. - fn set_with_arg(&mut self, node_fn: F, index: ExecGraphIdx, arg: ExecGraphIdx) + pub fn set_with_arg(&mut self, node_fn: F, index: ExecGraphIdx, arg: ExecGraphIdx) where F: Fn(u32) -> ExecGraphNode, { @@ -131,13 +132,13 @@ impl ExecGraphBuilder { } /// Removes all nodes after and including the given index. - fn truncate(&mut self, idx: ExecGraphIdx) { + pub fn truncate(&mut self, idx: ExecGraphIdx) { self.no_debug.truncate(idx.no_debug_idx); self.debug.truncate(idx.debug_idx); } /// Removes the last pushed node. - fn pop(&mut self) { + pub fn pop(&mut self) { self.no_debug.pop(); self.debug.pop(); } @@ -181,6 +182,13 @@ impl Lowerer { self.exec_graph.take() } + /// Consumes the lowerer and returns the Assigner with watermarks + /// representing one-past-max for every ID category. + #[must_use] + pub fn into_assigner(self) -> Assigner { + self.assigner + } + pub fn lower_package( &mut self, package: &hir::Package, diff --git a/source/compiler/qsc_openqasm_compiler/src/compiler.rs b/source/compiler/qsc_openqasm_compiler/src/compiler.rs index c365c00ee9..cd33268d11 100644 --- a/source/compiler/qsc_openqasm_compiler/src/compiler.rs +++ b/source/compiler/qsc_openqasm_compiler/src/compiler.rs @@ -281,13 +281,12 @@ impl QasmCompiler { ) } - /// Gets the profile for compilation from the first profile - /// pragma if present, otherwise default to `Unrestricted`. - fn get_profile(&self) -> Profile { + /// Extracts the QIR profile from `OpenQASM` pragmas. 
+ fn get_profile(&self) -> Option { self.pragma_config .pragmas .get(&PragmaKind::QdkQirProfile) - .map_or(Profile::Unrestricted, |profile_str| { + .map(|profile_str| { Profile::from_str(profile_str.as_ref()).expect( "Invalid profile pragma; only a valid profile should be store in pragma_config.", ) diff --git a/source/compiler/qsc_openqasm_compiler/src/lib.rs b/source/compiler/qsc_openqasm_compiler/src/lib.rs index 96dd3bbbdd..8167b86ad7 100644 --- a/source/compiler/qsc_openqasm_compiler/src/lib.rs +++ b/source/compiler/qsc_openqasm_compiler/src/lib.rs @@ -253,10 +253,9 @@ pub struct QasmCompileUnit { /// The signature of the operation created from the QASM source code. /// None if the program type is `ProgramType::Fragments`. signature: Option, - /// The QIR profile used for the compilation. - /// This is used to determine the QIR profile that the generated code - /// will use. - profile: Profile, + /// The QIR profile for compilation, derived from pragmas. + /// Returns `None` if no profile pragma was specified in the `OpenQASM` source. + profile: Option, } /// Represents a QASM compilation unit. @@ -270,7 +269,7 @@ impl QasmCompileUnit { errors: Vec>, package: Package, signature: Option, - profile: Profile, + profile: Option, ) -> Self { Self { source_map, @@ -293,9 +292,9 @@ impl QasmCompileUnit { self.errors.clone() } - /// Returns the QIR target profile associated with the compilation unit. + /// Returns the optional QIR profile from `OpenQASM` pragmas. 
#[must_use] - pub fn profile(&self) -> Profile { + pub fn profile(&self) -> Option { self.profile } @@ -308,7 +307,7 @@ impl QasmCompileUnit { Vec>, Package, Option, - Profile, + Option, ) { ( self.source_map, diff --git a/source/compiler/qsc_openqasm_compiler/src/tests.rs b/source/compiler/qsc_openqasm_compiler/src/tests.rs index e9b08bfb98..18d883ce19 100644 --- a/source/compiler/qsc_openqasm_compiler/src/tests.rs +++ b/source/compiler/qsc_openqasm_compiler/src/tests.rs @@ -188,7 +188,12 @@ fn compile_qasm_to_qir(source: &str) -> Result> { let unit = compile(source)?; fail_on_compilation_errors(&unit); let package = unit.package; - let qir = generate_qir_from_ast(package, unit.source_map, unit.profile).map_err(|errors| { + let qir = generate_qir_from_ast( + package, + unit.source_map, + unit.profile.unwrap_or(Profile::Unrestricted), + ) + .map_err(|errors| { errors .iter() .map(|e| Report::new(e.clone())) @@ -216,6 +221,7 @@ fn compile_qasm_best_effort(source: &str) { config, ); let (sources, _, package, _, profile) = unit.into_tuple(); + let profile = profile.unwrap_or(Profile::Unrestricted); let (stdid, store) = package_store_with_stdlib(profile.into()); let dependencies = vec![(PackageId::CORE, None), (stdid, None)]; @@ -413,7 +419,7 @@ fn verify_qsharp_from_qasm_source( /// Verifies a Q# AST package (with namespaces) compiles through the Q# compiler. 
fn verify_qsharp_ast(unit: &QasmCompileUnit) -> miette::Result<(), Vec> { - let capabilities = unit.profile.into(); + let capabilities = unit.profile.unwrap_or(Profile::Unrestricted).into(); let (stdid, store) = package_store_with_stdlib(capabilities); let dependencies = vec![(PackageId::CORE, None), (stdid, None)]; let (_compiled, errors) = compile_ast( diff --git a/source/compiler/qsc_partial_eval/Cargo.toml b/source/compiler/qsc_partial_eval/Cargo.toml index 0488a38db3..b14168693e 100644 --- a/source/compiler/qsc_partial_eval/Cargo.toml +++ b/source/compiler/qsc_partial_eval/Cargo.toml @@ -26,6 +26,7 @@ expect-test = { workspace = true } indoc = { workspace = true } qsc = { path = "../qsc" } qsc_frontend = { path = "../qsc_frontend" } +qsc_passes = { path = "../qsc_passes" } [lints] workspace = true diff --git a/source/compiler/qsc_partial_eval/src/evaluation_context.rs b/source/compiler/qsc_partial_eval/src/evaluation_context.rs index 02e8e5e06d..429db4b1f4 100644 --- a/source/compiler/qsc_partial_eval/src/evaluation_context.rs +++ b/source/compiler/qsc_partial_eval/src/evaluation_context.rs @@ -274,6 +274,12 @@ impl Arg { } /// Represents the possible control flow options that an evaluation can have. +/// +/// Note: The `Return` variant is vestigial for the production pipeline. +/// The `return_unify` FIR transform pass eliminates all `ExprKind::Return` +/// nodes before partial evaluation runs. However, partial eval unit tests +/// bypass FIR transforms and evaluate raw FIR, so the `Return` variant +/// and its handling code remain for test compatibility. 
pub enum EvalControlFlow { Continue(Value), Return(Value), diff --git a/source/compiler/qsc_partial_eval/src/lib.rs b/source/compiler/qsc_partial_eval/src/lib.rs index 0994287cc9..947557a089 100644 --- a/source/compiler/qsc_partial_eval/src/lib.rs +++ b/source/compiler/qsc_partial_eval/src/lib.rs @@ -2576,7 +2576,7 @@ impl<'a> PartialEvaluator<'a> { let bin_op_variable_id = self.resource_manager.next_var(); let bin_op_rir_variable = match bin_op { - BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div => { + BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div | BinOp::Mod => { rir::Variable::new_double(bin_op_variable_id) } BinOp::Eq | BinOp::Neq | BinOp::Gt | BinOp::Gte | BinOp::Lt | BinOp::Lte => { @@ -2598,6 +2598,14 @@ impl<'a> PartialEvaluator<'a> { Instruction::Fdiv(lhs_operand, rhs_operand, bin_op_rir_variable) } + BinOp::Mod => { + if let Operand::Literal(Literal::Double(0.0)) = rhs_operand { + let error = EvalError::DivZero(bin_op_expr_span).into(); + return Err(error); + } + + Instruction::Frem(lhs_operand, rhs_operand, bin_op_rir_variable) + } BinOp::Eq => Instruction::Fcmp( FcmpConditionCode::OrderedAndEqual, lhs_operand, diff --git a/source/compiler/qsc_partial_eval/src/tests.rs b/source/compiler/qsc_partial_eval/src/tests.rs index d32db7eb31..3c3fadff24 100644 --- a/source/compiler/qsc_partial_eval/src/tests.rs +++ b/source/compiler/qsc_partial_eval/src/tests.rs @@ -27,8 +27,7 @@ use qsc_data_structures::{ target::{Profile, TargetCapabilityFlags}, }; use qsc_fir::fir::PackageStore; -use qsc_frontend::compile::PackageStore as HirPackageStore; -use qsc_lowerer::{Lowerer, map_hir_package_to_fir}; +use qsc_passes::lower_hir_to_fir; use qsc_rca::{Analyzer, PackageStoreComputeProperties}; use qsc_rir::{ passes::check_and_transform, @@ -216,8 +215,8 @@ impl CompilationContext { &[(std_id, None)], ) .expect("should be able to create a new compiler"); - let package_id = map_hir_package_to_fir(compiler.source_package_id()); - let fir_store = 
lower_hir_package_store(compiler.package_store()); + let (fir_store, package_id, _) = + lower_hir_to_fir(compiler.package_store(), compiler.source_package_id()); let analyzer = Analyzer::init(&fir_store, capabilities); let compute_properties = analyzer.analyze_all(); let package = fir_store.get(package_id); @@ -239,13 +238,3 @@ impl CompilationContext { } } } - -fn lower_hir_package_store(hir_package_store: &HirPackageStore) -> PackageStore { - let mut fir_store = PackageStore::new(); - for (id, unit) in hir_package_store { - let mut lowerer = Lowerer::new(); - let lowered_package = lowerer.lower_package(&unit.package, &fir_store); - fir_store.insert(map_hir_package_to_fir(id), lowered_package); - } - fir_store -} diff --git a/source/compiler/qsc_partial_eval/src/tests/loops.rs b/source/compiler/qsc_partial_eval/src/tests/loops.rs index db07f4d246..2c08eb5a45 100644 --- a/source/compiler/qsc_partial_eval/src/tests/loops.rs +++ b/source/compiler/qsc_partial_eval/src/tests/loops.rs @@ -1801,3 +1801,58 @@ fn dynamic_nested_loop() { Jump(7)"#]], ); } + +#[test] +fn classical_while_inside_dynamic_while_folds_mutable_variable() { + // Verifies that a classically-unrolled while loop nested inside the body of a + // dynamic (emit) while loop correctly folds mutable variables to their static + // values instead of treating them as dynamic variables. + let program = get_rir_program_with_capabilities( + indoc! { + r#" + namespace Test { + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + mutable total = 0; + while MResetZ(q) == One { + mutable i = 0; + while i < 3 { + set i += 1; + } + set total += i; + } + total + } + } + "#, + }, + TargetCapabilityFlags::Adaptive | TargetCapabilityFlags::BackwardsBranching, + ); + + // The inner `while i < 3` loop should be fully unrolled classically, + // and `i` should fold to 3. The outer loop emits branch instructions. 
+ assert_blocks( + &program, + &expect![[r#" + Blocks: + Block 0:Block: + Call id(1), args( Pointer, ) + Variable(0, Integer) = Store Integer(0) + Jump(1) + Block 1:Block: + Call id(2), args( Qubit(0), Result(0), ) + Variable(1, Boolean) = Call id(3), args( Result(0), ) + Variable(2, Boolean) = Store Variable(1, Boolean) + Branch Variable(2, Boolean), 3, 2 + Block 2:Block: + Variable(5, Integer) = Store Variable(0, Integer) + Call id(4), args( Variable(5, Integer), Tag(0, 3), ) + Return + Block 3:Block: + Variable(3, Integer) = Store Integer(0) + Variable(4, Integer) = Add Variable(0, Integer), Integer(3) + Variable(0, Integer) = Store Variable(4, Integer) + Jump(1)"#]], + ); +} diff --git a/source/compiler/qsc_partial_eval/src/tests/operators.rs b/source/compiler/qsc_partial_eval/src/tests/operators.rs index 5e733ebb5b..449256ace5 100644 --- a/source/compiler/qsc_partial_eval/src/tests/operators.rs +++ b/source/compiler/qsc_partial_eval/src/tests/operators.rs @@ -2137,6 +2137,28 @@ fn integer_exponentiation_with_lhs_classical_integer_and_rhs_classical_negative_ ); } +#[test] +fn integer_exponentiation_with_both_classical_and_rhs_negative_raises_error() { + let error = get_partial_evaluation_error(indoc! { + r#" + namespace Test { + @EntryPoint() + operation Main() : Int { + use q = Qubit(); + let _ = MResetZ(q); + 2 ^ -3 + } + } + "#, + }); + assert_error( + &error, + &expect![[ + r#"EvaluationFailed("negative integers cannot be used here: -3", PackageSpan { package: PackageId(2), span: Span { lo: 130, hi: 132 } })"# + ]], + ); +} + #[test] fn integer_exponentiation_with_lhs_dynamic_integer_and_rhs_classical_zero_integer() { let program = get_rir_program(indoc! { @@ -2821,6 +2843,85 @@ fn integer_equality_comparison_with_lhs_dynamic_integer_and_rhs_classical_intege ); } +#[test] +fn integer_equality_comparison_after_dynamic_mutation_is_not_constant_folded() { + let program = get_rir_program(indoc! 
{ + r#" + namespace Test { + @EntryPoint() + operation Main() : Bool { + use q = Qubit(); + mutable count = 0; + if MResetZ(q) == One { + set count += 1; + } + count == 1 + } + } + "#, + }); + let measurement_callable_id = CallableId(1); + assert_callable( + &program, + measurement_callable_id, + &expect![[r#" + Callable: + name: __quantum__rt__initialize + call_type: Regular + input_type: + [0]: Pointer + output_type: + body: "#]], + ); + let readout_callable_id = CallableId(2); + assert_callable( + &program, + readout_callable_id, + &expect![[r#" + Callable: + name: __quantum__qis__mresetz__body + call_type: Measurement + input_type: + [0]: Qubit + [1]: Result + output_type: + body: "#]], + ); + let output_record_id = CallableId(3); + assert_callable( + &program, + output_record_id, + &expect![[r#" + Callable: + name: __quantum__rt__read_result + call_type: Readout + input_type: + [0]: Result + output_type: Boolean + body: "#]], + ); + assert_blocks( + &program, + &expect![[r#" + Blocks: + Block 0:Block: + Call id(1), args( Pointer, ) + Variable(0, Integer) = Store Integer(0) + Call id(2), args( Qubit(0), Result(0), ) + Variable(1, Boolean) = Call id(3), args( Result(0), ) + Variable(2, Boolean) = Store Variable(1, Boolean) + Branch Variable(2, Boolean), 2, 1 + Block 1:Block: + Variable(3, Boolean) = Icmp Eq, Variable(0, Integer), Integer(1) + Variable(4, Boolean) = Store Variable(3, Boolean) + Call id(4), args( Variable(4, Boolean), Tag(0, 3), ) + Return + Block 2:Block: + Variable(0, Integer) = Store Integer(1) + Jump(1)"#]], + ); +} + #[test] fn integer_inequality_comparison_with_lhs_dynamic_integer_and_rhs_dynamic_integer() { let program = get_rir_program(indoc! { @@ -4245,3 +4346,82 @@ fn double_less_or_equal_than_comparison_with_lhs_classical_double_and_rhs_dynami Jump(1)"#]], ); } + +#[test] +fn double_mod_with_lhs_dynamic_double_and_rhs_classical_double() { + let program = get_rir_program(indoc! 
{ + r#" + namespace Test { + @EntryPoint() + operation Main() : Double { + use q = Qubit(); + let i = MResetZ(q) == Zero ? 0.0 | 1.0; + i % 2.0 + } + } + "#, + }); + let measurement_callable_id = CallableId(1); + assert_callable( + &program, + measurement_callable_id, + &expect![[r#" + Callable: + name: __quantum__rt__initialize + call_type: Regular + input_type: + [0]: Pointer + output_type: + body: "#]], + ); + let readout_callable_id = CallableId(2); + assert_callable( + &program, + readout_callable_id, + &expect![[r#" + Callable: + name: __quantum__qis__mresetz__body + call_type: Measurement + input_type: + [0]: Qubit + [1]: Result + output_type: + body: "#]], + ); + let output_record_id = CallableId(3); + assert_callable( + &program, + output_record_id, + &expect![[r#" + Callable: + name: __quantum__rt__read_result + call_type: Readout + input_type: + [0]: Result + output_type: Boolean + body: "#]], + ); + assert_blocks( + &program, + &expect![[r#" + Blocks: + Block 0:Block: + Call id(1), args( Pointer, ) + Call id(2), args( Qubit(0), Result(0), ) + Variable(0, Boolean) = Call id(3), args( Result(0), ) + Variable(1, Boolean) = Icmp Eq, Variable(0, Boolean), Bool(false) + Branch Variable(1, Boolean), 2, 3 + Block 1:Block: + Variable(3, Double) = Store Variable(2, Double) + Variable(4, Double) = Frem Variable(3, Double), Double(2) + Variable(5, Double) = Store Variable(4, Double) + Call id(4), args( Variable(5, Double), Tag(0, 3), ) + Return + Block 2:Block: + Variable(2, Double) = Store Double(0) + Jump(1) + Block 3:Block: + Variable(2, Double) = Store Double(1) + Jump(1)"#]], + ); +} diff --git a/source/compiler/qsc_partial_eval/src/tests/qubits.rs b/source/compiler/qsc_partial_eval/src/tests/qubits.rs index fae82b7398..510b82ce37 100644 --- a/source/compiler/qsc_partial_eval/src/tests/qubits.rs +++ b/source/compiler/qsc_partial_eval/src/tests/qubits.rs @@ -310,6 +310,136 @@ fn qubit_array_allocation_and_access() { assert_eq!(program.num_results, 0); } 
+#[test] +fn qubit_array_length_is_preserved() { + let program = get_rir_program(indoc! { + r#" + namespace Test { + @EntryPoint() + operation Main() : Int { + use qs = Qubit[4]; + Length(qs) + } + } + "#, + }); + assert_block_instructions( + &program, + BlockId(0), + &expect![[r#" + Block: + Call id(1), args( Pointer, ) + Variable(0, Integer) = Store Integer(0) + Variable(0, Integer) = Store Integer(1) + Variable(0, Integer) = Store Integer(2) + Variable(0, Integer) = Store Integer(3) + Variable(0, Integer) = Store Integer(4) + Call id(2), args( Integer(4), Tag(0, 3), ) + Return"#]], + ); + assert_eq!(program.num_qubits, 4); + assert_eq!(program.num_results, 0); +} + +#[test] +fn qubit_array_chunks_can_be_indexed() { + let program = get_rir_program(indoc! { + r#" + namespace Test { + import Std.Arrays.*; + + operation Op(q : Qubit) : Unit { body intrinsic; } + + @EntryPoint() + operation Main() : Unit { + use qs = Qubit[4]; + let chunks = Chunks(2, qs); + Op(chunks[0][0]); + Op(chunks[1][1]); + } + } + "#, + }); + let op_callable_id = CallableId(1); + assert_callable( + &program, + op_callable_id, + &expect![[r#" + Callable: + name: __quantum__rt__initialize + call_type: Regular + input_type: + [0]: Pointer + output_type: + body: "#]], + ); + let tuple_callable_id = CallableId(2); + assert_callable( + &program, + tuple_callable_id, + &expect![[r#" + Callable: + name: Op + call_type: Regular + input_type: + [0]: Qubit + output_type: + body: "#]], + ); + assert_block_instructions( + &program, + BlockId(0), + &expect![[r#" + Block: + Call id(1), args( Pointer, ) + Variable(0, Integer) = Store Integer(0) + Variable(0, Integer) = Store Integer(1) + Variable(0, Integer) = Store Integer(2) + Variable(0, Integer) = Store Integer(3) + Variable(0, Integer) = Store Integer(4) + Call id(2), args( Qubit(0), ) + Call id(2), args( Qubit(3), ) + Call id(3), args( Integer(0), Tag(0, 3), ) + Return"#]], + ); + assert_eq!(program.num_qubits, 4); + assert_eq!(program.num_results, 0); 
+} + +#[test] +fn qubit_array_chunk_count_is_preserved() { + let program = get_rir_program(indoc! { + r#" + namespace Test { + import Std.Arrays.*; + + @EntryPoint() + operation Main() : Int { + use qs = Qubit[4]; + let chunks = Chunks(2, qs); + Length(chunks) + } + } + "#, + }); + assert_block_instructions( + &program, + BlockId(0), + &expect![[r#" + Block: + Call id(1), args( Pointer, ) + Variable(0, Integer) = Store Integer(0) + Variable(0, Integer) = Store Integer(1) + Variable(0, Integer) = Store Integer(2) + Variable(0, Integer) = Store Integer(3) + Variable(0, Integer) = Store Integer(4) + Call id(2), args( Integer(2), Tag(0, 3), ) + Return"#]], + ); + assert_eq!(program.num_qubits, 4); + assert_eq!(program.num_results, 0); +} + #[test] fn qubit_escaping_scope_triggers_runtime_error() { let error = get_partial_evaluation_error(indoc! { diff --git a/source/compiler/qsc_passes/src/capabilitiesck.rs b/source/compiler/qsc_passes/src/capabilitiesck.rs index 2bf707903d..2efc3d7c8b 100644 --- a/source/compiler/qsc_passes/src/capabilitiesck.rs +++ b/source/compiler/qsc_passes/src/capabilitiesck.rs @@ -24,7 +24,7 @@ use qsc_fir::{ Item, ItemKind, LocalItemId, LocalVarId, Package, PackageLookup, Pat, PatId, PatKind, Res, SpecDecl, SpecImpl, Stmt, StmtId, StmtKind, }, - ty::FunctorSetValue, + ty::{FunctorSetValue, Prim, Ty}, visit::{Visitor, walk_callable_decl}, }; @@ -37,15 +37,21 @@ use qsc_rca::{ use rustc_hash::FxHashMap; /// Lower a package store from `qsc_frontend` HIR store to a `qsc_fir` FIR store. +/// +/// Returns the FIR store and the `Assigner` from the final (user) package +/// lowering. The Assigner watermarks are past all IDs produced during lowering. 
pub fn lower_store( package_store: &qsc_frontend::compile::PackageStore, -) -> qsc_fir::fir::PackageStore { +) -> (qsc_fir::fir::PackageStore, qsc_fir::assigner::Assigner) { let mut fir_store = qsc_fir::fir::PackageStore::new(); + let mut last_assigner = qsc_fir::assigner::Assigner::new(); for (id, unit) in package_store { - let package = qsc_lowerer::Lowerer::new().lower_package(&unit.package, &fir_store); + let mut lowerer = qsc_lowerer::Lowerer::new(); + let package = lowerer.lower_package(&unit.package, &fir_store); fir_store.insert(map_hir_package_to_fir(id), package); + last_assigner = lowerer.into_assigner(); } - fir_store + (fir_store, last_assigner) } pub fn run_rca_pass( @@ -95,6 +101,31 @@ pub fn check_supported_capabilities( checker.check_all() } +/// Checks whether a single callable's runtime features are supported by the target capabilities. +/// +/// Returns capability-check errors for any expressions within the callable that require +/// runtime features exceeding `capabilities`. Returns an empty vector if the callable +/// was removed by DCE, is not a callable item, or uses no unsupported features. +#[must_use] +pub fn check_supported_capabilities_for_callable( + package: &Package, + compute_properties: &PackageComputeProperties, + callable: LocalItemId, + capabilities: TargetCapabilityFlags, + store: &qsc_fir::fir::PackageStore, +) -> Vec { + let checker = Checker { + package, + compute_properties, + target_capabilities: capabilities, + current_callable: None, + missing_features_map: FxHashMap::::default(), + store, + }; + + checker.check_callable(callable) +} + struct Checker<'a> { package: &'a Package, compute_properties: &'a PackageComputeProperties, @@ -212,6 +243,24 @@ impl<'a> Checker<'a> { self.generate_errors() } + pub fn check_callable(mut self, callable: LocalItemId) -> Vec { + let Some(current_callable) = self.package.get_global(callable) else { + // Item was removed by DCE (e.g., original generic after monomorphization). 
+ return self.generate_errors(); + }; + let Global::Callable(callable_decl) = current_callable else { + // Non-callable item — nothing to check. + return self.generate_errors(); + }; + + self.set_current_callable(callable); + self.visit_callable_decl(callable_decl); + let callable_id = self.clear_current_callable(); + assert!(callable == callable_id); + self.check_callable_output(callable_decl); + self.generate_errors() + } + fn check_entry_expr(&mut self, expr_id: ExprId) { let expr = self.get_expr(expr_id); if expr.span == Span::default() { @@ -384,6 +433,19 @@ impl<'a> Checker<'a> { } } + fn check_callable_output(&mut self, callable_decl: &CallableDecl) { + let missing_features = get_missing_runtime_features( + output_recording_runtime_features_for_ty(&callable_decl.output), + self.target_capabilities, + ) & RuntimeFeatureFlags::output_recording_flags(); + if !missing_features.is_empty() { + self.missing_features_map + .entry(callable_decl.name.span) + .and_modify(|f| *f |= missing_features) + .or_insert(missing_features); + } + } + fn clear_current_callable(&mut self) -> LocalItemId { self.current_callable .take() @@ -449,3 +511,36 @@ fn get_spec_level_runtime_features(runtime_features: RuntimeFeatureFlags) -> Run RuntimeFeatureFlags::CyclicOperationSpec; runtime_features & SPEC_LEVEL_RUNTIME_FEATURES } + +fn output_recording_runtime_features_for_ty(ty: &Ty) -> RuntimeFeatureFlags { + match ty { + Ty::Array(item) => output_recording_runtime_features_for_ty(item), + Ty::Prim(prim) => output_recording_runtime_features_for_prim(*prim), + Ty::Tuple(items) => items + .iter() + .fold(RuntimeFeatureFlags::empty(), |features, item| { + features | output_recording_runtime_features_for_ty(item) + }), + Ty::Arrow(_) | Ty::Udt(_) => RuntimeFeatureFlags::UseOfAdvancedOutput, + Ty::Infer(_) => panic!("cannot derive runtime features for `Infer` type"), + Ty::Param(_) => panic!("cannot derive runtime features for `Param` type"), + Ty::Err => panic!("cannot derive runtime 
features for `Err` type"), + } +} + +fn output_recording_runtime_features_for_prim(prim: Prim) -> RuntimeFeatureFlags { + match prim { + Prim::Bool => RuntimeFeatureFlags::UseOfBoolOutput, + Prim::Double => RuntimeFeatureFlags::UseOfDoubleOutput, + Prim::Int => RuntimeFeatureFlags::UseOfIntOutput, + Prim::Result => RuntimeFeatureFlags::empty(), + Prim::BigInt + | Prim::Pauli + | Prim::Qubit + | Prim::Range + | Prim::RangeFrom + | Prim::RangeTo + | Prim::RangeFull + | Prim::String => RuntimeFeatureFlags::UseOfAdvancedOutput, + } +} diff --git a/source/compiler/qsc_passes/src/lib.rs b/source/compiler/qsc_passes/src/lib.rs index d5ffb36611..9f81155728 100644 --- a/source/compiler/qsc_passes/src/lib.rs +++ b/source/compiler/qsc_passes/src/lib.rs @@ -20,7 +20,10 @@ mod spec_gen; mod test_attribute; use callable_limits::CallableLimits; -use capabilitiesck::{check_supported_capabilities, lower_store, run_rca_pass}; +use capabilitiesck::{ + check_supported_capabilities, check_supported_capabilities_for_callable, lower_store, + run_rca_pass, +}; use entry_point::generate_entry_expr; use index_assignment::ConvertToWSlash; use loop_unification::LoopUni; @@ -70,10 +73,14 @@ pub enum PackageType { pub fn lower_hir_to_fir( package_store: &qsc_frontend::compile::PackageStore, package_id: qsc_hir::hir::PackageId, -) -> (fir::PackageStore, fir::PackageId) { - let fir_store = lower_store(package_store); +) -> ( + fir::PackageStore, + fir::PackageId, + qsc_fir::assigner::Assigner, +) { + let (fir_store, assigner) = lower_store(package_store); let fir_package_id = map_hir_package_to_fir(package_id); - (fir_store, fir_package_id) + (fir_store, fir_package_id, assigner) } pub struct PassContext { @@ -190,7 +197,7 @@ pub fn run_core_passes(core: &mut CompileUnit) -> Vec { borrow_errors.into_iter().map(Error::BorrowCk).collect() } -pub fn run_fir_passes( +pub fn run_rca( package: &fir::Package, compute_properties: &PackageComputeProperties, capabilities: TargetCapabilityFlags, @@ -203,3 
+210,25 @@ pub fn run_fir_passes( .map(Error::CapabilitiesCk) .collect() } + +pub fn run_rca_for_callable( + fir_store: &fir::PackageStore, + compute_properties: &PackageStoreComputeProperties, + callable: fir::StoreItemId, + capabilities: TargetCapabilityFlags, +) -> Vec { + let package = fir_store.get(callable.package); + let package_compute_properties = compute_properties.get(callable.package); + let capabilities_errors = check_supported_capabilities_for_callable( + package, + package_compute_properties, + callable.item, + capabilities, + fir_store, + ); + + capabilities_errors + .into_iter() + .map(Error::CapabilitiesCk) + .collect() +} diff --git a/source/compiler/qsc_passes/src/replace_qubit_allocation/tests.rs b/source/compiler/qsc_passes/src/replace_qubit_allocation/tests.rs index 2b3e27c86d..453b570fb0 100644 --- a/source/compiler/qsc_passes/src/replace_qubit_allocation/tests.rs +++ b/source/compiler/qsc_passes/src/replace_qubit_allocation/tests.rs @@ -8,7 +8,13 @@ use qsc_data_structures::{ language_features::LanguageFeatures, source::SourceMap, target::TargetCapabilityFlags, }; use qsc_frontend::compile::{self, PackageStore, compile}; -use qsc_hir::{mut_visit::MutVisitor, validate::Validator, visit::Visitor}; +use qsc_hir::{ + hir::{ItemKind, PatKind, SpecBody, StmtKind}, + mut_visit::MutVisitor, + ty::{Prim, Ty}, + validate::Validator, + visit::Visitor, +}; fn check(file: &str, expect: &Expect) { let store = PackageStore::new(compile::core()); @@ -26,6 +32,22 @@ fn check(file: &str, expect: &Expect) { expect.assert_eq(&unit.package.to_string()); } +fn rewrite(file: &str) -> qsc_hir::hir::Package { + let store = PackageStore::new(compile::core()); + let sources = SourceMap::new([("test".into(), file.into())], None); + let mut unit = compile( + &store, + &[], + sources, + TargetCapabilityFlags::all(), + LanguageFeatures::default(), + ); + assert!(unit.errors.is_empty(), "{:?}", unit.errors); + ReplaceQubitAllocation::new(store.core(), &mut 
unit.assigner).visit_package(&mut unit.package); + Validator::default().visit_package(&unit.package); + unit.package +} + #[test] fn test_single_qubit() { check( @@ -65,6 +87,38 @@ fn test_single_qubit() { ); } +#[test] +fn test_explicitly_annotated_single_qubit_rewrite_preserves_binding_name_and_types() { + let package = rewrite(indoc! { "namespace input { + operation Foo() : Unit { + use q : Qubit = Qubit(); + let x = 3; + } + }" }); + + let callable = package + .items + .values() + .find_map(|item| match &item.kind { + ItemKind::Callable(callable) if callable.name.name.as_ref() == "Foo" => Some(callable), + _ => None, + }) + .expect("Foo callable should exist"); + let SpecBody::Impl(_, block) = &callable.body.body else { + panic!("Foo should have an implementation body"); + }; + let StmtKind::Local(_, pat, expr) = &block.stmts[0].kind else { + panic!("first statement should be the rewritten qubit allocation local"); + }; + + assert_eq!(pat.ty, Ty::Prim(Prim::Qubit)); + assert_eq!(expr.ty, Ty::Prim(Prim::Qubit)); + let PatKind::Bind(ident) = &pat.kind else { + panic!("rewritten qubit allocation should still bind q"); + }; + assert_eq!(ident.name.as_ref(), "q"); +} + #[test] fn test_qubit_array() { check( @@ -730,7 +784,7 @@ fn test_array_expr() { } #[test] -fn test_rtrn_expr() { +fn return_expression_with_nested_qubit_scope_rewrites_correctly() { check( indoc! 
{ "namespace input { operation Foo() : Int { diff --git a/source/compiler/qsc_rca/Cargo.toml b/source/compiler/qsc_rca/Cargo.toml index f254bf46a4..b52e9f54ed 100644 --- a/source/compiler/qsc_rca/Cargo.toml +++ b/source/compiler/qsc_rca/Cargo.toml @@ -23,6 +23,7 @@ thiserror = { workspace = true } [dev-dependencies] expect-test = { workspace = true } qsc = { path = "../qsc" } +qsc_fir_transforms = { path = "../qsc_fir_transforms", features = ["testutil"] } qsc_passes = { path = "../qsc_passes" } [lints] diff --git a/source/compiler/qsc_rca/src/analyzer.rs b/source/compiler/qsc_rca/src/analyzer.rs index 94f9d68cd6..7a045acbff 100644 --- a/source/compiler/qsc_rca/src/analyzer.rs +++ b/source/compiler/qsc_rca/src/analyzer.rs @@ -56,7 +56,10 @@ impl<'a> Analyzer<'a> { // Now we can safely analyze the rest of the items. let core_analyzer = core::Analyzer::new(self.package_store, scaffolding, self.target_capabilities); - core_analyzer.analyze_all().into() + let result: PackageStoreComputeProperties = core_analyzer.analyze_all().into(); + #[cfg(debug_assertions)] + crate::invariants::assert_arity_consistency(self.package_store, &result); + result } #[must_use] @@ -68,6 +71,14 @@ impl<'a> Analyzer<'a> { let scaffolding = cyclic_callables_analyzer.analyze_package(package_id); let core_analyzer = core::Analyzer::new(self.package_store, scaffolding, self.target_capabilities); - core_analyzer.analyze_package(package_id).into() + let result: PackageStoreComputeProperties = + core_analyzer.analyze_package(package_id).into(); + // Note: `analyze_package` is the incremental compiler path (language + // service); the full-store invariant is still valuable for catching + // regressions introduced by incremental updates, so run it here in + // debug builds as well. 
+ #[cfg(debug_assertions)] + crate::invariants::assert_arity_consistency(self.package_store, &result); + result } } diff --git a/source/compiler/qsc_rca/src/applications.rs b/source/compiler/qsc_rca/src/applications.rs index be1a3786af..44ab55f77c 100644 --- a/source/compiler/qsc_rca/src/applications.rs +++ b/source/compiler/qsc_rca/src/applications.rs @@ -318,9 +318,14 @@ impl GeneratorSetsBuilder { inherent: block_inherent_compute_kind, dynamic_param_applications: block_dynamic_param_applications, }; + debug_assert!( + application_generator_set.dynamic_param_applications.len() == input_params_count, + "RCA invariant: block {block_id:?} application generator has {} param applications but callable has {input_params_count} input params", + application_generator_set.dynamic_param_applications.len(), + ); package_compute_properties .blocks - .insert(block_id, application_generator_set); + .insert_if_absent(block_id, application_generator_set); } // Save an applications generator set for each statement using their compute properties. @@ -340,9 +345,14 @@ impl GeneratorSetsBuilder { inherent: stmt_inherent_compute_kind, dynamic_param_applications: stmt_dynamic_param_applications, }; + debug_assert!( + application_generator_set.dynamic_param_applications.len() == input_params_count, + "RCA invariant: stmt {stmt_id:?} application generator has {} param applications but callable has {input_params_count} input params", + application_generator_set.dynamic_param_applications.len(), + ); package_compute_properties .stmts - .insert(stmt_id, application_generator_set); + .insert_if_absent(stmt_id, application_generator_set); } // Save an applications generator set for each expression using their compute properties. 
@@ -362,9 +372,14 @@ impl GeneratorSetsBuilder { inherent: expr_inherent_compute_kind, dynamic_param_applications: expr_dynamic_param_applications, }; + debug_assert!( + application_generator_set.dynamic_param_applications.len() == input_params_count, + "RCA invariant: expr {expr_id:?} application generator has {} param applications but callable has {input_params_count} input params", + application_generator_set.dynamic_param_applications.len(), + ); package_compute_properties .exprs - .insert(expr_id, application_generator_set); + .insert_if_absent(expr_id, application_generator_set); } // Save the unresolved callee expressions. diff --git a/source/compiler/qsc_rca/src/core.rs b/source/compiler/qsc_rca/src/core.rs index 30f0b7da32..7a9ea97b2a 100644 --- a/source/compiler/qsc_rca/src/core.rs +++ b/source/compiler/qsc_rca/src/core.rs @@ -1277,7 +1277,11 @@ impl<'a> Analyzer<'a> { } fn analyze_spec(&mut self, id: GlobalSpecId, callable_decl: &'a CallableDecl) { - // Only do this if the specialization has not been analyzed already. + // Early-return: skip re-analysis of already-analyzed specializations. + // With insert-if-absent at the scaffolding level, this guard is no longer + // required for overwrite correctness, but it remains necessary for: + // 1. Cycle prevention in cyclic callable analysis + // 2. 
Performance (avoids redundant analysis of already-complete specs) if self .package_store_compute_properties .find_specialization(id) @@ -1569,7 +1573,7 @@ impl<'a> Analyzer<'a> { fn unanalyzed_stmts(&self, package_id: PackageId) -> Vec { let package = self.package_store.get(package_id); let mut unanalyzed_stmts = Vec::new(); - for (stmt_id, _) in &package.stmts { + for (stmt_id, _stmt) in &package.stmts { if self .package_store_compute_properties .find_stmt((package_id, stmt_id).into()) @@ -1736,10 +1740,13 @@ impl<'a> Analyzer<'a> { let current_package = self.package_store.get(self.get_current_package_id()); let mut stmt_collector = StmtCollector::new(current_package); stmt_collector.visit_block(block_id); + let callable_context = self.get_current_item_context().get_callable_context(); + let default_generator_set = + default_application_generator_set_for_callable(callable_context); for stmt_id in stmt_collector.stmts { self.package_store_compute_properties.insert_stmt( (self.get_current_package_id(), stmt_id).into(), - ApplicationGeneratorSet::default(), + default_generator_set.clone(), ); } } @@ -2250,6 +2257,42 @@ enum CallComputeKind { Override(ComputeKind), } +/// Builds a neutral, arity-matched `ApplicationGeneratorSet` for a callable whose body +/// statements are being marked as "visited" without analysis (e.g. `@SimulatableIntrinsic` +/// and `@Test` callable bodies in `set_all_stmts_in_block_to_default`). +/// +/// The generator set must have `dynamic_param_applications` whose length matches the +/// owning callable's input-parameter arity so the invariant check in +/// `invariants.rs` (and any downstream consumer of these sentinel stmts) does not see a +/// zero-arity entry where a non-zero arity is expected. 
Each entry is a conservative +/// neutral shape: scalar parameters map to `ParamApplication::Element(ComputeKind::Static)` +/// and array parameters map to `ParamApplication::Array` with a `Static` static-size +/// compute kind and a conservative `Dynamic` dynamic-size compute kind. +fn default_application_generator_set_for_callable( + callable_context: &CallableContext, +) -> ApplicationGeneratorSet { + let mut dynamic_param_applications = + Vec::::with_capacity(callable_context.input_params.len()); + for param in &callable_context.input_params { + let param_application = match ¶m.ty { + Ty::Array(_) => ParamApplication::Array(ArrayParamApplication { + static_size: ComputeKind::Static, + dynamic_size: ComputeKind::Dynamic { + runtime_features: RuntimeFeatureFlags::UseOfDynamicallySizedArray, + value_kind: ValueKind::Variable, + }, + }), + _ => ParamApplication::Element(ComputeKind::Static), + }; + dynamic_param_applications.push(param_application); + } + + ApplicationGeneratorSet { + inherent: ComputeKind::Static, + dynamic_param_applications, + } +} + fn derive_intrinsic_function_application_generator_set( callable_context: &CallableContext, ) -> ApplicationGeneratorSet { diff --git a/source/compiler/qsc_rca/src/invariants.rs b/source/compiler/qsc_rca/src/invariants.rs new file mode 100644 index 0000000000..4f8f620683 --- /dev/null +++ b/source/compiler/qsc_rca/src/invariants.rs @@ -0,0 +1,196 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Debug-only invariant checks for RCA results. +//! +//! This module provides a post-walk (`assert_arity_consistency`) that verifies +//! every `ApplicationGeneratorSet` recorded in a `PackageStoreComputeProperties` +//! has `dynamic_param_applications.len()` matching the arity (i.e., the number +//! of flattened input parameters) of its owning callable specialization, or +//! `0` for top-level statements and entry expressions. +//! +//! 
The module is gated on `#[cfg(debug_assertions)]` so release builds compile +//! it out entirely. + +use crate::{ApplicationGeneratorSet, PackageStoreComputeProperties}; +use qsc_fir::{ + fir::{ + Block, BlockId, CallableImpl, Expr, ExprId, ItemKind, Package, PackageId, PackageStore, + Pat, PatId, SpecDecl, Stmt, StmtId, + }, + visit::{self, Visitor}, +}; +use rustc_hash::FxHashMap; + +/// Walks `store` and `props` and asserts that every recorded +/// `ApplicationGeneratorSet.dynamic_param_applications` vector has the arity +/// of its owning specialization (or `0` for top-level statements and entry +/// expressions). +/// +/// Every package in the store is checked. Entries whose ownership cannot be +/// resolved from the FIR walk are silently skipped (see [`check_entry`]). +pub(crate) fn assert_arity_consistency( + store: &PackageStore, + props: &PackageStoreComputeProperties, +) { + for (package_id, package) in store { + let ownership = collect_ownership(package_id, package); + let package_props = props.get(package_id); + + for (block_id, generator) in package_props.blocks.iter() { + check_entry( + package_id, + ElementKey::Block(block_id), + generator, + &ownership, + ); + } + for (stmt_id, generator) in package_props.stmts.iter() { + check_entry(package_id, ElementKey::Stmt(stmt_id), generator, &ownership); + } + for (expr_id, generator) in package_props.exprs.iter() { + check_entry(package_id, ElementKey::Expr(expr_id), generator, &ownership); + } + } +} + +#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)] +enum ElementKey { + Block(BlockId), + Stmt(StmtId), + Expr(ExprId), +} + +fn check_entry( + package_id: PackageId, + key: ElementKey, + generator: &ApplicationGeneratorSet, + ownership: &FxHashMap, +) { + let Some(&expected) = ownership.get(&key) else { + // Unknown ownership is silently tolerated: this indicates either a + // synthesized element not contributing to RCA, or a gap that a + // future invariant refinement should cover. 
+ return; + }; + let actual = generator.dynamic_param_applications.len(); + debug_assert!( + actual == expected, + "RCA invariant: package {package_id:?} {key:?} application generator has {actual} \ + param applications but owning specialization has arity {expected}", + ); +} + +fn collect_ownership(package_id: PackageId, package: &Package) -> FxHashMap { + let mut collector = OwnershipCollector { + package, + map: FxHashMap::default(), + current_arity: 0, + }; + + // Walk each callable item so spec-owned IDs are recorded with the + // callable's input-pat arity. Top-level statements and the entry + // expression are recorded after item walks with arity 0. + for (_, item) in &package.items { + if let ItemKind::Callable(callable) = &item.kind { + let arity = package.derive_callable_input_params(callable).len(); + collector.current_arity = arity; + match &callable.implementation { + CallableImpl::Spec(spec_impl) => { + collector.visit_spec_decl(&spec_impl.body); + if let Some(spec) = spec_impl.adj.as_ref() { + collector.visit_spec_decl(spec); + } + if let Some(spec) = spec_impl.ctl.as_ref() { + collector.visit_spec_decl(spec); + } + if let Some(spec) = spec_impl.ctl_adj.as_ref() { + collector.visit_spec_decl(spec); + } + } + // `SimulatableIntrinsic` bodies are not analyzed by the core + // analyzer; their stmts receive an arity-matched default + // generator set via `core::set_all_stmts_in_block_to_default`. + // Record ownership at the callable's arity so the invariant + // sees a consistent view. + CallableImpl::SimulatableIntrinsic(spec_decl) => { + collector.visit_spec_decl(spec_decl); + } + // `Intrinsic` callables have no body to walk. + CallableImpl::Intrinsic => {} + } + } + } + + // Top-level stmts + entry expr live outside any spec and have arity 0. 
+ collector.current_arity = 0; + for (stmt_id, _) in &package.stmts { + collector.map.entry(ElementKey::Stmt(stmt_id)).or_insert(0); + } + if let Some(entry_expr) = package.entry { + collector + .map + .entry(ElementKey::Expr(entry_expr)) + .or_insert(0); + // Walk the entry expression tree so nested exprs/blocks/stmts are + // captured too. + collector.visit_expr(entry_expr); + } + + let _ = package_id; // Kept for signature symmetry / future diagnostics. + collector.map +} + +struct OwnershipCollector<'a> { + package: &'a Package, + map: FxHashMap, + current_arity: usize, +} + +impl<'a> Visitor<'a> for OwnershipCollector<'a> { + fn get_block(&self, id: BlockId) -> &'a Block { + self.package.blocks.get(id).expect("block should exist") + } + + fn get_expr(&self, id: ExprId) -> &'a Expr { + self.package.exprs.get(id).expect("expr should exist") + } + + fn get_pat(&self, id: PatId) -> &'a Pat { + self.package.pats.get(id).expect("pat should exist") + } + + fn get_stmt(&self, id: StmtId) -> &'a Stmt { + self.package.stmts.get(id).expect("stmt should exist") + } + + fn visit_block(&mut self, id: BlockId) { + // First-wins insertion prevents a later arity-0 entry-expression + // walk from clobbering a spec-body arity recorded by the earlier + // item walk. The sharing case is dormant today but this hardening + // removes a latent aliasing hazard at zero cost. + self.map + .entry(ElementKey::Block(id)) + .or_insert(self.current_arity); + visit::walk_block(self, id); + } + + fn visit_stmt(&mut self, id: StmtId) { + self.map + .entry(ElementKey::Stmt(id)) + .or_insert(self.current_arity); + visit::walk_stmt(self, id); + } + + fn visit_expr(&mut self, id: ExprId) { + self.map + .entry(ElementKey::Expr(id)) + .or_insert(self.current_arity); + visit::walk_expr(self, id); + } + + fn visit_spec_decl(&mut self, decl: &'a SpecDecl) { + // Skip pat to avoid recording pattern IDs (we only track blocks/stmts/exprs). 
+ self.visit_block(decl.block); + } +} diff --git a/source/compiler/qsc_rca/src/lib.rs b/source/compiler/qsc_rca/src/lib.rs index 0dee514c99..8f0bfb6507 100644 --- a/source/compiler/qsc_rca/src/lib.rs +++ b/source/compiler/qsc_rca/src/lib.rs @@ -16,6 +16,8 @@ mod core; mod cycle_detection; mod cyclic_callables; pub mod errors; +#[cfg(debug_assertions)] +mod invariants; mod overrider; mod scaffolding; @@ -352,7 +354,15 @@ impl ApplicationGeneratorSet { &self, args_compute_kinds: &[ComputeKind], ) -> ComputeKind { - assert!(self.dynamic_param_applications.len() == args_compute_kinds.len()); + // RCA generators record one `ParamApplication` per flattened input + // parameter of the owning callable. The runtime arg vector must match + // exactly; any skew indicates a bug in the analyzer's recording path. + assert!( + self.dynamic_param_applications.len() == args_compute_kinds.len(), + "application generator recorded {} parameter applications for {} runtime arguments", + self.dynamic_param_applications.len(), + args_compute_kinds.len() + ); let mut compute_kind = self.inherent; for (arg_compute_kind, param_application) in args_compute_kinds .iter() diff --git a/source/compiler/qsc_rca/src/overrider.rs b/source/compiler/qsc_rca/src/overrider.rs index cafad3c691..cab26d67a5 100644 --- a/source/compiler/qsc_rca/src/overrider.rs +++ b/source/compiler/qsc_rca/src/overrider.rs @@ -8,8 +8,8 @@ use crate::{ }; use qsc_fir::{ fir::{ - Block, BlockId, CallableImpl, Expr, ExprId, Global, Item, ItemKind, LocalItemId, Package, - PackageStore, PackageStoreLookup, Pat, PatId, Stmt, StmtId, + Block, BlockId, CallableImpl, Expr, ExprId, Global, ItemKind, Package, PackageStore, + PackageStoreLookup, Pat, PatId, Stmt, StmtId, }, ty::FunctorSetValue, visit::{Visitor, walk_block, walk_expr, walk_stmt}, @@ -94,15 +94,6 @@ impl<'a> Overrider<'a> { .expect("current package should be valid") } - fn get_item(&self, id: LocalItemId) -> &'a Item { - let package_id = self.get_current_package(); - 
self.package_store - .get(package_id) - .items - .get(id) - .expect("item not found") - } - fn populate_package_internal(&mut self, package_id: PackageId, package: &'a Package) { self.current_package = Some(package_id); self.visit_package(package, self.package_store); @@ -201,7 +192,8 @@ impl<'a> Visitor<'a> for Overrider<'a> { let callables = namespace_items .iter() .filter_map(|item_id| { - let item = self.get_item(*item_id); + let package_id = self.get_current_package(); + let item = self.package_store.get(package_id).items.get(*item_id)?; match &item.kind { ItemKind::Callable(decl) => Some((item.id, decl.name.name.to_string())), _ => None, diff --git a/source/compiler/qsc_rca/src/scaffolding.rs b/source/compiler/qsc_rca/src/scaffolding.rs index ed3885f87d..83d3388fcd 100644 --- a/source/compiler/qsc_rca/src/scaffolding.rs +++ b/source/compiler/qsc_rca/src/scaffolding.rs @@ -153,11 +153,15 @@ impl InternalPackageStoreComputeProperties { } pub fn insert_block(&mut self, id: StoreBlockId, value: ApplicationGeneratorSet) { - self.get_mut(id.package).blocks.insert(id.block, value); + self.get_mut(id.package) + .blocks + .insert_if_absent(id.block, value); } pub fn insert_expr(&mut self, id: StoreExprId, value: ApplicationGeneratorSet) { - self.get_mut(id.package).exprs.insert(id.expr, value); + self.get_mut(id.package) + .exprs + .insert_if_absent(id.expr, value); } pub fn insert_item(&mut self, id: StoreItemId, value: InternalItemComputeProperties) { @@ -171,7 +175,8 @@ impl InternalPackageStoreComputeProperties { item_compute_properties { // The item already exists but not the specialization. 
- specializations.insert(SpecializationIndex::from(id.functor_set_value), value); + specializations + .insert_if_absent(SpecializationIndex::from(id.functor_set_value), value); } else { panic!("item should be a callable"); } @@ -187,7 +192,9 @@ impl InternalPackageStoreComputeProperties { } pub fn insert_stmt(&mut self, id: StoreStmtId, value: ApplicationGeneratorSet) { - self.get_mut(id.package).stmts.insert(id.stmt, value); + self.get_mut(id.package) + .stmts + .insert_if_absent(id.stmt, value); } } diff --git a/source/compiler/qsc_rca/src/tests.rs b/source/compiler/qsc_rca/src/tests.rs index 2057229e68..1ecb02e6bb 100644 --- a/source/compiler/qsc_rca/src/tests.rs +++ b/source/compiler/qsc_rca/src/tests.rs @@ -10,11 +10,13 @@ mod calls; mod cycles; mod ifs; mod intrinsics; +mod invariants_strict; mod lambdas; mod loops; mod measurements; mod overrides; mod qubits; +mod return_unify_interactions; mod strings; mod structs; mod types; @@ -101,6 +103,70 @@ impl Default for CompilationContext { } } +/// A fixture that mirrors [`CompilationContext`] but runs the FIR transform +/// pipeline over the lowered FIR store before instantiating the RCA +/// [`Analyzer`]. Used by arity-consistency and return-unify interaction tests +/// that need RCA results to reflect post-pipeline FIR. +/// +/// The pipeline requires an executable package (with an entry expression), so +/// this fixture compiles Q# source plus an explicit entry string via +/// [`qsc_fir_transforms::test_utils::compile_to_fir_with_entry`]. +pub struct PipelineContext { + pub fir_store: PackageStore, + pub user_package_id: qsc_fir::fir::PackageId, + pub compute_properties: PackageStoreComputeProperties, +} + +impl PipelineContext { + /// Builds a pipeline context from Q# `source` and an executable `entry` + /// expression. 
+ #[must_use] + pub fn new(source: &str, entry: &str, capabilities: TargetCapabilityFlags) -> Self { + let (mut fir_store, user_package_id) = + qsc_fir_transforms::test_utils::compile_to_fir_with_entry(source, entry); + // CONTRACT: On success, `run_pipeline` produces FIR satisfying `InvariantLevel::PostAll`. + // The RCA `Analyzer` assumes PostAll invariants hold — in particular, no closures or + // unresolved type parameters remain in reachable code. See + // `qsc_fir_transforms::invariants::check` for the authoritative checker. + let errors = qsc_fir_transforms::run_pipeline(&mut fir_store, user_package_id); + assert!( + errors.is_empty(), + "FIR transform pipeline reported errors: {errors:?}" + ); + let analyzer = Analyzer::init(&fir_store, capabilities); + let compute_properties = analyzer.analyze_all(); + Self { + fir_store, + user_package_id, + compute_properties, + } + } + + #[must_use] + pub fn get_compute_properties(&self) -> &PackageStoreComputeProperties { + &self.compute_properties + } +} + +impl Default for PipelineContext { + fn default() -> Self { + Self::new("", "()", TargetCapabilityFlags::all()) + } +} + +#[test] +fn pipeline_context_smoke() { + let context = PipelineContext::default(); + assert!( + context.get_compute_properties().iter().count() > 0, + "pipeline context should produce compute properties for at least one package", + ); + let _ = context.fir_store.get(context.user_package_id); + let _ = context + .get_compute_properties() + .get(context.user_package_id); +} + pub trait PackageStoreSearch { fn find_callable_id_by_name(&self, name: &str) -> Option; } diff --git a/source/compiler/qsc_rca/src/tests/invariants_strict.rs b/source/compiler/qsc_rca/src/tests/invariants_strict.rs new file mode 100644 index 0000000000..61c2748117 --- /dev/null +++ b/source/compiler/qsc_rca/src/tests/invariants_strict.rs @@ -0,0 +1,282 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! 
Strict-invariant regression tests for the default-generator-set correction. +//! +//! Each test exercises a code shape that was a known source of +//! `actual==0 && expected>0` arity skew prior to the default-generator-set +//! correction and +//! relied on the old tolerance rule in +//! `invariants::check_entry`. With the tolerance removed and strict `==` +//! enforced both in `lib.rs` (`generate_application_compute_kind`) and +//! `invariants.rs` (`debug_assert!(actual == expected, ..)`), any future +//! regression that re-introduces arity-0 saves over spec-owned stmts/blocks +//! will panic here in debug builds. +//! +//! Each test: +//! 1. Runs the RCA pipeline to completion via `CompilationContext` (or +//! `PipelineContext` when post-FIR-transform behavior is being covered). +//! Both contexts call into `Analyzer::analyze_all` / `analyze_package`, +//! which runs `assert_arity_consistency` under `#[cfg(debug_assertions)]`. +//! 2. Adds an explicit positive arity check on a representative callable so a +//! silent regression (invariant disabled or weakened) still surfaces as a +//! test failure. + +use qsc_data_structures::target::Profile; + +use super::{CompilationContext, PackageStoreSearch, PipelineContext}; +use crate::{ComputePropertiesLookup, ItemComputeProperties}; + +/// Returns the `dynamic_param_applications` length recorded for the body spec +/// of the callable named `callable_name` in `context`. 
+fn body_arity(context: &CompilationContext, callable_name: &str) -> usize { + let id = context + .fir_store + .find_callable_id_by_name(callable_name) + .unwrap_or_else(|| panic!("callable {callable_name} should exist")); + let ItemComputeProperties::Callable(props) = context.get_compute_properties().get_item(id) + else { + panic!("{callable_name} should be a callable item"); + }; + props.body.dynamic_param_applications.len() +} + +/// Class 1 (arity 1): `@SimulatableIntrinsic` operation whose body stmts are +/// written via `set_all_stmts_in_block_to_default`. Under the old +/// `ApplicationGeneratorSet::default()` writes, every stmt in the body was +/// saved at arity 0; the debug invariant reported expected arity 1 and +/// tolerated the skew. The default-generator-set correction now saves +/// arity-matched generators directly. +#[test] +fn simulatable_intrinsic_arity_one_body_matches_input_params() { + let mut context = CompilationContext::default(); + context.update( + r#" + @SimulatableIntrinsic() + operation SimIntrinsic1(q : Qubit) : Unit { + H(q); + let x = 1; + Message($"x = {x}"); + }"#, + ); + assert_eq!( + body_arity(&context, "SimIntrinsic1"), + 1, + "SimulatableIntrinsic body arity must match input-pat arity", + ); +} + +/// Class 1 (arity 2): same as above with a two-parameter input pat. +#[test] +fn simulatable_intrinsic_arity_two_body_matches_input_params() { + let mut context = CompilationContext::default(); + context.update( + r#" + @SimulatableIntrinsic() + operation SimIntrinsic2(q : Qubit, i : Int) : Unit { + H(q); + let y = i + 1; + Message($"y = {y}"); + }"#, + ); + assert_eq!( + body_arity(&context, "SimIntrinsic2"), + 2, + "SimulatableIntrinsic body arity must match input-pat arity", + ); +} + +/// Class 1 (arity 3, mixed scalar/array): covers the `ParamApplication::Array` +/// construction path inside `default_application_generator_set_for_callable`. 
+#[test] +fn simulatable_intrinsic_arity_three_with_array_param_body_matches_input_params() { + let mut context = CompilationContext::default(); + context.update( + r#" + @SimulatableIntrinsic() + operation SimIntrinsic3(q : Qubit, i : Int, arr : Int[]) : Unit { + H(q); + let z = i + Length(arr); + Message($"z = {z}"); + }"#, + ); + assert_eq!( + body_arity(&context, "SimIntrinsic3"), + 3, + "SimulatableIntrinsic body arity must match input-pat arity", + ); +} + +/// Class 2: `@Test` callable with a non-trivial measurement-driven body. +/// Previously the body stmts were saved at arity 0 by the top-level sweep +/// (`@Test` bodies are not entered by the main analyzer path). The body is +/// arity 0 because `@Test` callables take no parameters, but the regression +/// target here is that the invariant runs to completion on a `@Test` body +/// without triggering any intermediate skew on inner stmts/blocks. +#[test] +fn test_attribute_callable_body_reaches_strict_invariant() { + let mut context = CompilationContext::default(); + context.update( + r#" + @Test() + operation TestSample() : Int { + use q = Qubit(); + mutable a = 0; + if M(q) == Zero { + set a = 1; + } + Message($"a = {a}"); + return a; + }"#, + ); + assert_eq!( + body_arity(&context, "TestSample"), + 0, + "@Test callable body arity must match the empty input pat", + ); +} + +/// End-to-end fixture: a minimal reduction of `samples/algorithms/DeutschJozsa.qs` +/// exercising multiple callables, a dynamic measurement loop, and an array +/// parameter. This is Class 3 coverage — prior to the narrowing of +/// `unanalyzed_stmts`, the top-level sweep would overwrite spec-body stmts at +/// arity 0 for programs of this shape. +#[test] +fn deutsch_jozsa_shape_passes_strict_invariant() { + let mut context = CompilationContext::default(); + context.update( + r#" + operation ConstantOracle(qs : Qubit[], target : Qubit) : Unit is Adj + Ctl { + body ... 
{ } + adjoint self; + } + + operation BalancedOracle(qs : Qubit[], target : Qubit) : Unit is Adj + Ctl { + body ... { + for q in qs { + CNOT(q, target); + } + } + } + + operation DeutschJozsaMini(oracle : (Qubit[], Qubit) => Unit is Adj + Ctl, n : Int) : Bool { + use qs = Qubit[n]; + use target = Qubit(); + X(target); + H(target); + for q in qs { + H(q); + } + oracle(qs, target); + for q in qs { + H(q); + } + mutable isConstant = true; + for q in qs { + if M(q) == One { + set isConstant = false; + } + } + Reset(target); + ResetAll(qs); + return isConstant; + } + + operation MainMini() : Bool[] { + [ + DeutschJozsaMini(ConstantOracle, 3), + DeutschJozsaMini(BalancedOracle, 3) + ] + }"#, + ); + assert_eq!( + body_arity(&context, "DeutschJozsaMini"), + 2, + "DeutschJozsaMini takes (oracle, n) — body arity must be 2", + ); + assert_eq!( + body_arity(&context, "MainMini"), + 0, + "MainMini has no input parameters — body arity must be 0", + ); + assert_eq!( + body_arity(&context, "ConstantOracle"), + 2, + "ConstantOracle takes (qs, target) — body arity must be 2", + ); +} + +/// Mutual recursion: cyclic callables are analyzed by the dedicated +/// `cyclic_callables::Analyzer` pass, which pre-populates spec-body +/// generators at arity N. Historically the subsequent `TopLevelContext` +/// sweep could overwrite these at arity 0 when a cyclic spec-body stmt was +/// not tracked as "already analyzed". Phase 2's spec-owned-stmt filter +/// prevents the overwrite; this test guards against a regression. 
+#[test] +fn mutual_recursion_passes_strict_invariant() { + let mut context = CompilationContext::default(); + context.update( + r#" + function Ping(n : Int) : Int { + if n <= 0 { + return 0; + } + return Pong(n - 1); + } + + function Pong(n : Int) : Int { + if n <= 0 { + return 0; + } + return Ping(n - 1); + }"#, + ); + assert_eq!( + body_arity(&context, "Ping"), + 1, + "Ping body arity must match its single Int input parameter", + ); + assert_eq!( + body_arity(&context, "Pong"), + 1, + "Pong body arity must match its single Int input parameter", + ); +} + +/// Dynamic return via an early-exit inside a measurement-driven branch. This +/// exercises the `return_unify` FIR pass. Uses `PipelineContext` to force the +/// FIR transform pipeline (including GC) to run before RCA. +#[test] +fn dynamic_return_pipeline_passes_strict_invariant() { + let source = r#" + namespace Test { + operation DynReturnStrict(qs : Qubit[]) : Result[] { + mutable results = [Zero, size = Length(qs)]; + mutable i = 0; + while i < Length(qs) { + if M(qs[i]) == One { + return results; + } + set i += 1; + } + results + } + } + "#; + let entry = "{ use qs = Qubit[2]; Test.DynReturnStrict(qs) }"; + let context = PipelineContext::new(source, entry, Profile::AdaptiveRIF.into()); + let dyn_return_id = context + .fir_store + .find_callable_id_by_name("DynReturnStrict") + .expect("DynReturnStrict should exist after pipeline lowering"); + let ItemComputeProperties::Callable(props) = + context.get_compute_properties().get_item(dyn_return_id) + else { + panic!("DynReturnStrict should be a callable item"); + }; + assert_eq!( + props.body.dynamic_param_applications.len(), + 1, + "DynReturnStrict body arity must match its single Qubit[] input parameter", + ); +} diff --git a/source/compiler/qsc_rca/src/tests/return_unify_interactions.rs b/source/compiler/qsc_rca/src/tests/return_unify_interactions.rs new file mode 100644 index 0000000000..7639237525 --- /dev/null +++ 
b/source/compiler/qsc_rca/src/tests/return_unify_interactions.rs @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Tests that exercise RCA behavior in the presence of FIR transforms that +//! desugar `return` (arity-consistency / return-unify interaction coverage). The `return_unify` pass introduces a +//! synthetic flag-based early-return when a `return` appears inside a dynamic +//! scope (e.g. `if M(q) == One { return ... }`). Historically this interacted +//! badly with RCA's `dynamic_param_applications` arity invariants; the +//! `assert_arity_consistency` post-walker (see +//! `source/compiler/qsc_rca/src/invariants.rs`) now runs in debug builds at +//! the end of `Analyzer::analyze_all` / `Analyzer::analyze_package` to catch +//! skew regressions. + +use qsc_data_structures::target::Profile; + +use super::{PackageStoreSearch, PipelineContext}; +use crate::{ComputeKind, ComputePropertiesLookup, ItemComputeProperties, ValueKind}; + +/// Return-unify regression: after the return-unification pass rewrites a dynamic-scope +/// `return` into a flag-based fallback, RCA must produce a coherent +/// `ApplicationGeneratorSet` for the enclosing callable's body spec. The +/// measurement-driven dynamism guarantees the value kind is `Variable`. +/// +/// Regression note: the implicit arity-consistency invariant is enforced by +/// `PipelineContext::new`, which invokes `Analyzer::analyze_all` and therefore +/// runs `assert_arity_consistency` on the user package. Reverting the +/// arity-consistency invariant (or regressing the return-unify pass so arities diverge from +/// `CallableImpl` input counts) would flip that implicit assertion into a skew +/// panic before the explicit `ComputeKind` check below is reached. 
+#[test] +fn flag_fallback_value_kind_after_dynamic_scope_return() { + let source = r#" + namespace Test { + operation DynReturn(qs : Qubit[]) : Result[] { + mutable results = [Zero, size = Length(qs)]; + mutable i = 0; + while i < Length(qs) { + if M(qs[i]) == One { + return results; + } + set i += 1; + } + results + } + } + "#; + let entry = "{ use qs = Qubit[2]; Test.DynReturn(qs) }"; + + let context = PipelineContext::new(source, entry, Profile::AdaptiveRIF.into()); + + let dyn_return_id = context + .fir_store + .find_callable_id_by_name("DynReturn") + .expect("DynReturn callable should exist after pipeline lowering"); + + let item_props = context.get_compute_properties().get_item(dyn_return_id); + let ItemComputeProperties::Callable(callable_props) = item_props else { + panic!("DynReturn should be a callable item, got non-callable compute properties"); + }; + + match callable_props.body.inherent { + ComputeKind::Dynamic { value_kind, .. } => { + assert_eq!( + value_kind, + ValueKind::Variable, + "DynReturn body should be classified as Dynamic/Variable after the flag-fallback rewrite", + ); + } + ComputeKind::Static => { + panic!("DynReturn body should be Dynamic after measurement-driven return, got Static",); + } + } +} diff --git a/source/compiler/qsc_rir/src/passes/insert_alloca_load.rs b/source/compiler/qsc_rir/src/passes/insert_alloca_load.rs index 1ef4077ee8..e18004100c 100644 --- a/source/compiler/qsc_rir/src/passes/insert_alloca_load.rs +++ b/source/compiler/qsc_rir/src/passes/insert_alloca_load.rs @@ -127,6 +127,7 @@ fn add_alloca_load_to_block( | Instruction::Fsub(lhs, rhs, _) | Instruction::Fmul(lhs, rhs, _) | Instruction::Fdiv(lhs, rhs, _) + | Instruction::Frem(lhs, rhs, _) | Instruction::Fcmp(_, lhs, rhs, _) | Instruction::Icmp(_, lhs, rhs, _) | Instruction::LogicalAnd(lhs, rhs, _) diff --git a/source/compiler/qsc_rir/src/passes/prune_unneeded_stores.rs b/source/compiler/qsc_rir/src/passes/prune_unneeded_stores.rs index 280186cc18..c6895f7a49 
100644 --- a/source/compiler/qsc_rir/src/passes/prune_unneeded_stores.rs +++ b/source/compiler/qsc_rir/src/passes/prune_unneeded_stores.rs @@ -144,6 +144,7 @@ fn check_var_usage( | Instruction::Fsub(operand0, operand1, variable) | Instruction::Fmul(operand0, operand1, variable) | Instruction::Fdiv(operand0, operand1, variable) + | Instruction::Frem(operand0, operand1, variable) | Instruction::LogicalAnd(operand0, operand1, variable) | Instruction::LogicalOr(operand0, operand1, variable) | Instruction::BitwiseAnd(operand0, operand1, variable) diff --git a/source/compiler/qsc_rir/src/passes/ssa_check.rs b/source/compiler/qsc_rir/src/passes/ssa_check.rs index 5f3c453503..248655dd54 100644 --- a/source/compiler/qsc_rir/src/passes/ssa_check.rs +++ b/source/compiler/qsc_rir/src/passes/ssa_check.rs @@ -133,6 +133,8 @@ fn get_variable_uses(program: &Program) -> IndexMap IndexMap IndexMap { write_binary_instruction(f, "Fdiv", lhs, rhs, *variable)?; } + Self::Frem(lhs, rhs, variable) => { + write_binary_instruction(f, "Frem", lhs, rhs, *variable)?; + } Self::Fcmp(op, lhs, rhs, variable) => { write_fcmp_instruction(f, *op, lhs, rhs, *variable)?; } diff --git a/source/compiler/qsc_rir/src/utils.rs b/source/compiler/qsc_rir/src/utils.rs index 6c777b8ca9..5bbad5d5bf 100644 --- a/source/compiler/qsc_rir/src/utils.rs +++ b/source/compiler/qsc_rir/src/utils.rs @@ -90,6 +90,7 @@ pub fn get_variable_assignments(program: &Program) -> IndexMap, V> IndexMap { self.values[index] = Some(value); } + /// Inserts a value at the given index only if no value is already present. + /// Returns `true` if the value was inserted, `false` if a value already existed. 
+ pub fn insert_if_absent(&mut self, key: K, value: V) -> bool { + let index = key.into(); + if index >= self.values.len() { + self.values.resize_with(index + 1, || None); + } + if self.values[index].is_none() { + self.values[index] = Some(value); + true + } else { + false + } + } + pub fn contains_key(&self, key: K) -> bool { let index: usize = key.into(); self.values.get(index).is_some_and(Option::is_some) diff --git a/source/index_map/src/tests.rs b/source/index_map/src/tests.rs new file mode 100644 index 0000000000..43c68ca5b9 --- /dev/null +++ b/source/index_map/src/tests.rs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use super::*; + +#[test] +fn insert_if_absent_into_empty_returns_true() { + let mut map: IndexMap = IndexMap::new(); + assert!(map.insert_if_absent(0, 42)); + assert_eq!(*map.get(0).expect("IndexMap::get: index out of bounds"), 42); +} + +#[test] +fn insert_if_absent_occupied_returns_false_preserves_original() { + let mut map: IndexMap = IndexMap::new(); + map.insert(0, 42); + assert!(!map.insert_if_absent(0, 99)); + assert_eq!(*map.get(0).expect("IndexMap::get: index out of bounds"), 42); +} + +#[test] +fn insert_if_absent_extends_capacity_for_sparse_key() { + let mut map: IndexMap = IndexMap::new(); + assert!(map.insert_if_absent(100, 7)); + assert_eq!( + *map.get(100).expect("IndexMap::get: index out of bounds"), + 7 + ); + assert!(!map.contains_key(0)); +} diff --git a/source/language_service/src/compilation.rs b/source/language_service/src/compilation.rs index 802a26ee9f..bcc0748685 100644 --- a/source/language_service/src/compilation.rs +++ b/source/language_service/src/compilation.rs @@ -252,7 +252,7 @@ impl Compilation { ); let res = qsc::openqasm::semantic::parse_sources(&sources); let unit = compile_to_qsharp_ast_with_config(res, config); - let target_profile = unit.profile(); + let target_profile = unit.profile().unwrap_or(Profile::Unrestricted); let CompileRawQasmResult(store, 
source_package_id, _, _sig, mut compile_errors) = qsc::openqasm::compile_openqasm(unit, package_type); @@ -413,7 +413,34 @@ fn run_fir_passes( return; } - let (fir_store, fir_package_id) = qsc::lower_hir_to_fir(package_store, package_id); + let (mut fir_store, fir_package_id, _assigner) = + qsc::lower_hir_to_fir(package_store, package_id); + + // Run FIR transforms (monomorphize, defunctionalize, etc.) before capability checking. + // This matches the codegen pipeline ordering in qsc/src/codegen.rs. + // The transforms require an entry expression (defunctionalize uses reachability from entry), + // so only run when the package has one. + if fir_store.get(fir_package_id).entry.is_some() { + // CONTRACT: On success, `run_pipeline` produces FIR satisfying `InvariantLevel::PostAll`: + // - No `Ty::Param` in reachable code (monomorphization completed). + // - No `ExprKind::Return` in reachable code (return unification completed). + // - No `Ty::Arrow` params / `ExprKind::Closure` (defunctionalization completed). + // - No `Ty::Udt` / `ExprKind::Struct` / `Field::Path` (UDT erasure completed). + // - All exec-graph ranges populated (exec-graph rebuild completed). + // RCA (capability checking) assumes these invariants hold. See + // `qsc_fir_transforms::invariants::check` for the authoritative checker. 
+ let transform_errors = qsc::fir_transforms::run_pipeline(&mut fir_store, fir_package_id); + if !transform_errors.is_empty() { + for err in transform_errors { + errors.push(WithSource::from_map( + &unit.sources, + compile::ErrorKind::FirTransform(err), + )); + } + return; // Don't run RCA on invalid FIR + } + } + let caps_results = PassContext::run_fir_passes_on_fir(&fir_store, fir_package_id, target_profile.into()); if let Err(caps_errors) = caps_results { diff --git a/source/npm/qsharp/test/circuits-cases/lambda.qs.snapshot.html b/source/npm/qsharp/test/circuits-cases/lambda.qs.snapshot.html index 02fa063825..b3f163720a 100644 --- a/source/npm/qsharp/test/circuits-cases/lambda.qs.snapshot.html +++ b/source/npm/qsharp/test/circuits-cases/lambda.qs.snapshot.html @@ -80,7 +80,7 @@ - lambda.qs:3:24 let lambda = (q => H(q)); + lambda.qs:4:5 lambda(q); - lambda.qs:3:24 let lambda = (q => H(q)); + lambda.qs:4:5 lambda(q); PyResult { let kwargs = kwargs.unwrap_or_else(|| PyDict::new(py)); - let target = get_target_profile(&kwargs)?; + let user_profile = get_target_profile(&kwargs)?; let operation_name = get_operation_name(&kwargs)?; let search_path = get_search_path(&kwargs)?; @@ -334,11 +335,12 @@ pub(crate) fn compile_qasm_program_to_qir( let program_ty = ProgramType::File; let output_semantics = get_output_semantics(&kwargs, || OutputSemantics::OpenQasm)?; - let (package, source_map, signature) = + let (package, source_map, signature, pragma_profile) = compile_qasm_enriching_errors(res, &operation_name, program_ty, output_semantics, false)?; let package_type = PackageType::Lib; let language_features = LanguageFeatures::default(); + let target = user_profile.unwrap_or(pragma_profile.unwrap_or(Profile::Unrestricted)); let mut interpreter = create_interpreter_from_ast(package, source_map, target, language_features, package_type) .map_err(|errors| QSharpError::new_err(format_errors(errors)))?; @@ -353,7 +355,7 @@ pub(crate) fn compile_qasm_enriching_errors>( program_ty: 
ProgramType, output_semantics: OutputSemantics, allow_input_params: bool, -) -> PyResult<(Package, SourceMap, OperationSignature)> { +) -> PyResult<(Package, SourceMap, OperationSignature, Option)> { let config = qsc::openqasm::CompilerConfig::new( QubitSemantics::Qiskit, output_semantics.into(), @@ -364,7 +366,7 @@ pub(crate) fn compile_qasm_enriching_errors>( let unit = compile_to_qsharp_ast_with_config(semantic_parse_result, config); - let (source_map, errors, package, sig, _) = unit.into_tuple(); + let (source_map, errors, package, sig, pragma_profile) = unit.into_tuple(); if !errors.is_empty() { return Err(QasmError::new_err(format_qasm_errors(errors))); } @@ -385,7 +387,7 @@ pub(crate) fn compile_qasm_enriching_errors>( return Err(QSharpError::new_err(message)); } - Ok((package, source_map, signature)) + Ok((package, source_map, signature, pragma_profile)) } fn generate_qir_from_ast>( @@ -433,7 +435,7 @@ pub(crate) fn compile_qasm_to_qsharp( let program_ty = get_program_type(&kwargs, || ProgramType::File)?; let output_semantics = get_output_semantics(&kwargs, || OutputSemantics::OpenQasm)?; - let (package, _, _) = + let (package, _, _, _) = compile_qasm_enriching_errors(res, &operation_name, program_ty, output_semantics, true)?; let qsharp = qsc::codegen::qsharp::write_package_string(&package); @@ -596,7 +598,7 @@ pub(crate) fn circuit_qasm_program( }; let res = qsc::openqasm::semantic::parse_sources(&sources); - let (package, source_map, signature) = compile_qasm_enriching_errors( + let (package, source_map, signature, pragma_profile) = compile_qasm_enriching_errors( res, &operation_name, ProgramType::File, @@ -612,7 +614,7 @@ pub(crate) fn circuit_qasm_program( ) { TargetProfile::Adaptive_RIF.into() } else { - TargetProfile::Unrestricted.into() + pragma_profile.unwrap_or(Profile::Unrestricted) }; let mut interpreter = create_interpreter_from_ast( @@ -821,10 +823,10 @@ pub(crate) fn get_operation_name(kwargs: &Bound<'_, PyDict>) -> PyResult /// /// This also 
maps the `TargetProfile` exposed to Python to a `Profile` /// used by the interpreter. -pub(crate) fn get_target_profile(kwargs: &Bound<'_, PyDict>) -> PyResult { +pub(crate) fn get_target_profile(kwargs: &Bound<'_, PyDict>) -> PyResult> { match kwargs.get_item("target_profile")? { - Some(obj) => Ok(obj.extract::()?.into()), - None => Ok(TargetProfile::Unrestricted.into()), + Some(obj) => Ok(Some(obj.extract::()?.into())), + None => Ok(None), } } diff --git a/source/pip/tests-integration/resources/adaptive_rifla/output/BernsteinVaziraniNISQ.ll b/source/pip/tests-integration/resources/adaptive_rifla/output/BernsteinVaziraniNISQ.ll index 6f42df0635..13e0697e82 100644 --- a/source/pip/tests-integration/resources/adaptive_rifla/output/BernsteinVaziraniNISQ.ll +++ b/source/pip/tests-integration/resources/adaptive_rifla/output/BernsteinVaziraniNISQ.ll @@ -4,25 +4,48 @@ @3 = internal constant [6 x i8] c"3_a2r\00" @4 = internal constant [6 x i8] c"4_a3r\00" @5 = internal constant [6 x i8] c"5_a4r\00" +@array0 = internal constant [5 x ptr] [ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 1 to ptr), ptr inttoptr (i64 2 to ptr), ptr inttoptr (i64 3 to ptr), ptr inttoptr (i64 4 to ptr)] define i64 @ENTRYPOINT__main() #0 { block_0: + %var_1 = alloca i64 + %var_7 = alloca i64 call void @__quantum__rt__initialize(ptr null) call void @__quantum__qis__x__body(ptr inttoptr (i64 5 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 0 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 1 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 2 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 3 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 4 to ptr)) + store i64 0, ptr %var_1 + br label %block_1 +block_1: + %var_13 = load i64, ptr %var_1 + %var_2 = icmp slt i64 %var_13, 5 + br i1 %var_2, label %block_2, label %block_3 +block_2: + %var_19 = load i64, ptr %var_1 + %var_3 = getelementptr ptr, ptr @array0, i64 %var_19 + 
%var_20 = load ptr, ptr %var_3 + call void @__quantum__qis__h__body(ptr %var_20) + %var_5 = add i64 %var_19, 1 + store i64 %var_5, ptr %var_1 + br label %block_1 +block_3: call void @__quantum__qis__h__body(ptr inttoptr (i64 5 to ptr)) call void @__quantum__qis__cx__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 5 to ptr)) call void @__quantum__qis__cx__body(ptr inttoptr (i64 2 to ptr), ptr inttoptr (i64 5 to ptr)) call void @__quantum__qis__cx__body(ptr inttoptr (i64 4 to ptr), ptr inttoptr (i64 5 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 4 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 3 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 2 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 1 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 0 to ptr)) + store i64 4, ptr %var_7 + br label %block_4 +block_4: + %var_15 = load i64, ptr %var_7 + %var_8 = icmp sge i64 %var_15, 0 + br i1 %var_8, label %block_5, label %block_6 +block_5: + %var_16 = load i64, ptr %var_7 + %var_9 = getelementptr ptr, ptr @array0, i64 %var_16 + %var_17 = load ptr, ptr %var_9 + call void @__quantum__qis__h__body(ptr %var_17) + %var_11 = add i64 %var_16, -1 + store i64 %var_11, ptr %var_7 + br label %block_4 +block_6: call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 0 to ptr)) call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 1 to ptr), ptr inttoptr (i64 1 to ptr)) call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 2 to ptr), ptr inttoptr (i64 2 to ptr)) diff --git a/source/pip/tests-integration/resources/adaptive_rifla/output/CopyAndUpdateExpressions.ll b/source/pip/tests-integration/resources/adaptive_rifla/output/CopyAndUpdateExpressions.ll index 6e360a9f87..01937c78d7 100644 --- a/source/pip/tests-integration/resources/adaptive_rifla/output/CopyAndUpdateExpressions.ll +++ 
b/source/pip/tests-integration/resources/adaptive_rifla/output/CopyAndUpdateExpressions.ll @@ -7,9 +7,11 @@ @6 = internal constant [8 x i8] c"6_t2a0r\00" @7 = internal constant [8 x i8] c"7_t2a1r\00" @8 = internal constant [8 x i8] c"8_t2a2r\00" +@array0 = internal constant [2 x ptr] [ptr inttoptr (i64 3 to ptr), ptr inttoptr (i64 4 to ptr)] define i64 @ENTRYPOINT__main() #0 { block_0: + %var_2 = alloca i64 call void @__quantum__rt__initialize(ptr null) call void @__quantum__qis__x__body(ptr inttoptr (i64 0 to ptr)) call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 0 to ptr)) @@ -18,8 +20,21 @@ block_0: call void @__quantum__qis__m__body(ptr inttoptr (i64 2 to ptr), ptr inttoptr (i64 2 to ptr)) call void @__quantum__qis__m__body(ptr inttoptr (i64 3 to ptr), ptr inttoptr (i64 3 to ptr)) call void @__quantum__qis__m__body(ptr inttoptr (i64 4 to ptr), ptr inttoptr (i64 4 to ptr)) - call void @__quantum__qis__x__body(ptr inttoptr (i64 3 to ptr)) - call void @__quantum__qis__x__body(ptr inttoptr (i64 4 to ptr)) + store i64 0, ptr %var_2 + br label %block_1 +block_1: + %var_8 = load i64, ptr %var_2 + %var_3 = icmp slt i64 %var_8, 2 + br i1 %var_3, label %block_2, label %block_3 +block_2: + %var_9 = load i64, ptr %var_2 + %var_4 = getelementptr ptr, ptr @array0, i64 %var_9 + %var_10 = load ptr, ptr %var_4 + call void @__quantum__qis__x__body(ptr %var_10) + %var_6 = add i64 %var_9, 1 + store i64 %var_6, ptr %var_2 + br label %block_1 +block_3: call void @__quantum__qis__m__body(ptr inttoptr (i64 3 to ptr), ptr inttoptr (i64 5 to ptr)) call void @__quantum__qis__m__body(ptr inttoptr (i64 4 to ptr), ptr inttoptr (i64 6 to ptr)) call void @__quantum__rt__tuple_record_output(i64 3, ptr @0) diff --git a/source/pip/tests-integration/resources/adaptive_rifla/output/ExpandedTests.ll b/source/pip/tests-integration/resources/adaptive_rifla/output/ExpandedTests.ll index dfe90bcc3e..0de7c928e3 100644 --- 
a/source/pip/tests-integration/resources/adaptive_rifla/output/ExpandedTests.ll +++ b/source/pip/tests-integration/resources/adaptive_rifla/output/ExpandedTests.ll @@ -3,49 +3,55 @@ @2 = internal constant [8 x i8] c"2_t0a0r\00" @3 = internal constant [8 x i8] c"3_t0a1r\00" @4 = internal constant [6 x i8] c"4_t1r\00" +@array0 = internal constant [2 x ptr] [ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 1 to ptr)] +@array1 = internal constant [1 x ptr] [ptr inttoptr (i64 0 to ptr)] define i64 @ENTRYPOINT__main() #0 { block_0: - %var_2 = alloca i64 - %var_4 = alloca i1 + %var_1 = alloca i64 + %var_6 = alloca i64 + %var_8 = alloca i1 + %var_9 = alloca i64 + %var_14 = alloca i64 + %var_19 = alloca i64 + %var_24 = alloca i64 + %var_29 = alloca i64 + %var_34 = alloca i64 call void @__quantum__rt__initialize(ptr null) - call void @__quantum__qis__h__body(ptr inttoptr (i64 0 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 1 to ptr)) - store i64 0, ptr %var_2 + store i64 0, ptr %var_1 br label %block_1 block_1: - %var_13 = load i64, ptr %var_2 - %var_3 = icmp sle i64 %var_13, 0 - store i1 true, ptr %var_4 - br i1 %var_3, label %block_2, label %block_3 + %var_41 = load i64, ptr %var_1 + %var_2 = icmp slt i64 %var_41, 2 + br i1 %var_2, label %block_2, label %block_3 block_2: - %var_16 = load i1, ptr %var_4 - br i1 %var_16, label %block_4, label %block_5 + %var_79 = load i64, ptr %var_1 + %var_3 = getelementptr ptr, ptr @array0, i64 %var_79 + %var_80 = load ptr, ptr %var_3 + call void @__quantum__qis__h__body(ptr %var_80) + %var_5 = add i64 %var_79, 1 + store i64 %var_5, ptr %var_1 + br label %block_1 block_3: - store i1 false, ptr %var_4 - br label %block_2 + store i64 0, ptr %var_6 + br label %block_4 block_4: + %var_43 = load i64, ptr %var_6 + %var_7 = icmp sle i64 %var_43, 0 + store i1 true, ptr %var_8 + br i1 %var_7, label %block_5, label %block_6 +block_5: + %var_46 = load i1, ptr %var_8 + br i1 %var_46, label %block_7, label %block_8 +block_6: + store i1 
false, ptr %var_8 + br label %block_5 +block_7: call void @__quantum__qis__x__body(ptr inttoptr (i64 2 to ptr)) call void @__quantum__qis__h__body(ptr inttoptr (i64 2 to ptr)) - call void @__quantum__qis__x__body(ptr inttoptr (i64 0 to ptr)) - call void @__quantum__qis__ccx__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 1 to ptr), ptr inttoptr (i64 2 to ptr)) - call void @__quantum__qis__x__body(ptr inttoptr (i64 0 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 2 to ptr)) - call void @__quantum__qis__x__body(ptr inttoptr (i64 2 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 1 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 0 to ptr)) - call void @__quantum__qis__x__body(ptr inttoptr (i64 0 to ptr)) - call void @__quantum__qis__x__body(ptr inttoptr (i64 1 to ptr)) - call void @__quantum__qis__cz__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 1 to ptr)) - call void @__quantum__qis__x__body(ptr inttoptr (i64 1 to ptr)) - call void @__quantum__qis__x__body(ptr inttoptr (i64 0 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 0 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 1 to ptr)) - %var_17 = load i64, ptr %var_2 - %var_11 = add i64 %var_17, 1 - store i64 %var_11, ptr %var_2 - br label %block_1 -block_5: + store i64 0, ptr %var_9 + br label %block_9 +block_8: call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 0 to ptr)) call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 1 to ptr), ptr inttoptr (i64 1 to ptr)) call void @__quantum__qis__h__body(ptr inttoptr (i64 2 to ptr)) @@ -67,6 +73,102 @@ block_5: call void @__quantum__rt__result_record_output(ptr inttoptr (i64 1 to ptr), ptr @3) call void @__quantum__rt__result_record_output(ptr inttoptr (i64 2 to ptr), ptr @4) ret i64 0 +block_9: + %var_48 = load i64, ptr %var_9 + %var_10 = icmp slt i64 %var_48, 1 + br i1 %var_10, label %block_10, label %block_11 +block_10: + %var_76 = load 
i64, ptr %var_9 + %var_11 = getelementptr ptr, ptr @array1, i64 %var_76 + %var_77 = load ptr, ptr %var_11 + call void @__quantum__qis__x__body(ptr %var_77) + %var_13 = add i64 %var_76, 1 + store i64 %var_13, ptr %var_9 + br label %block_9 +block_11: + call void @__quantum__qis__ccx__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 1 to ptr), ptr inttoptr (i64 2 to ptr)) + store i64 0, ptr %var_14 + br label %block_12 +block_12: + %var_50 = load i64, ptr %var_14 + %var_15 = icmp sge i64 %var_50, 0 + br i1 %var_15, label %block_13, label %block_14 +block_13: + %var_73 = load i64, ptr %var_14 + %var_16 = getelementptr ptr, ptr @array1, i64 %var_73 + %var_74 = load ptr, ptr %var_16 + call void @__quantum__qis__x__body(ptr %var_74) + %var_18 = add i64 %var_73, -1 + store i64 %var_18, ptr %var_14 + br label %block_12 +block_14: + call void @__quantum__qis__h__body(ptr inttoptr (i64 2 to ptr)) + call void @__quantum__qis__x__body(ptr inttoptr (i64 2 to ptr)) + store i64 1, ptr %var_19 + br label %block_15 +block_15: + %var_52 = load i64, ptr %var_19 + %var_20 = icmp sge i64 %var_52, 0 + br i1 %var_20, label %block_16, label %block_17 +block_16: + %var_70 = load i64, ptr %var_19 + %var_21 = getelementptr ptr, ptr @array0, i64 %var_70 + %var_71 = load ptr, ptr %var_21 + call void @__quantum__qis__h__body(ptr %var_71) + %var_23 = add i64 %var_70, -1 + store i64 %var_23, ptr %var_19 + br label %block_15 +block_17: + store i64 0, ptr %var_24 + br label %block_18 +block_18: + %var_54 = load i64, ptr %var_24 + %var_25 = icmp slt i64 %var_54, 2 + br i1 %var_25, label %block_19, label %block_20 +block_19: + %var_67 = load i64, ptr %var_24 + %var_26 = getelementptr ptr, ptr @array0, i64 %var_67 + %var_68 = load ptr, ptr %var_26 + call void @__quantum__qis__x__body(ptr %var_68) + %var_28 = add i64 %var_67, 1 + store i64 %var_28, ptr %var_24 + br label %block_18 +block_20: + call void @__quantum__qis__cz__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 1 to ptr)) + store i64 
1, ptr %var_29 + br label %block_21 +block_21: + %var_56 = load i64, ptr %var_29 + %var_30 = icmp sge i64 %var_56, 0 + br i1 %var_30, label %block_22, label %block_23 +block_22: + %var_64 = load i64, ptr %var_29 + %var_31 = getelementptr ptr, ptr @array0, i64 %var_64 + %var_65 = load ptr, ptr %var_31 + call void @__quantum__qis__x__body(ptr %var_65) + %var_33 = add i64 %var_64, -1 + store i64 %var_33, ptr %var_29 + br label %block_21 +block_23: + store i64 0, ptr %var_34 + br label %block_24 +block_24: + %var_58 = load i64, ptr %var_34 + %var_35 = icmp slt i64 %var_58, 2 + br i1 %var_35, label %block_25, label %block_26 +block_25: + %var_61 = load i64, ptr %var_34 + %var_36 = getelementptr ptr, ptr @array0, i64 %var_61 + %var_62 = load ptr, ptr %var_36 + call void @__quantum__qis__h__body(ptr %var_62) + %var_38 = add i64 %var_61, 1 + store i64 %var_38, ptr %var_34 + br label %block_24 +block_26: + %var_59 = load i64, ptr %var_6 + %var_39 = add i64 %var_59, 1 + store i64 %var_39, ptr %var_6 + br label %block_4 } declare void @__quantum__rt__initialize(ptr) diff --git a/source/pip/tests-integration/resources/adaptive_rifla/output/HiddenShiftNISQ.ll b/source/pip/tests-integration/resources/adaptive_rifla/output/HiddenShiftNISQ.ll index 0d65fb3bc8..1bf7c28de1 100644 --- a/source/pip/tests-integration/resources/adaptive_rifla/output/HiddenShiftNISQ.ll +++ b/source/pip/tests-integration/resources/adaptive_rifla/output/HiddenShiftNISQ.ll @@ -11,145 +11,175 @@ define i64 @ENTRYPOINT__main() #0 { block_0: - %var_2 = alloca i64 - %var_3 = alloca i64 - %var_6 = alloca ptr - %var_12 = alloca i64 - %var_14 = alloca i1 - %var_18 = alloca i64 - %var_19 = alloca i64 - %var_22 = alloca ptr - %var_29 = alloca i64 - %var_31 = alloca i1 + %var_1 = alloca i64 + %var_6 = alloca i64 + %var_7 = alloca i64 + %var_10 = alloca ptr + %var_16 = alloca i64 + %var_18 = alloca i1 + %var_22 = alloca i64 + %var_23 = alloca i64 + %var_26 = alloca ptr + %var_32 = alloca i64 + %var_37 = alloca i64 + 
%var_39 = alloca i1 + %var_43 = alloca i64 call void @__quantum__rt__initialize(ptr null) - call void @__quantum__qis__h__body(ptr inttoptr (i64 0 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 1 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 2 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 3 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 4 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 5 to ptr)) - store i64 33, ptr %var_2 - store i64 0, ptr %var_3 + store i64 0, ptr %var_1 br label %block_1 block_1: - %var_37 = load i64, ptr %var_3 - %var_4 = icmp slt i64 %var_37, 6 - br i1 %var_4, label %block_2, label %block_3 + %var_49 = load i64, ptr %var_1 + %var_2 = icmp slt i64 %var_49, 6 + br i1 %var_2, label %block_2, label %block_3 block_2: - %var_70 = load i64, ptr %var_3 - %var_5 = getelementptr ptr, ptr @array0, i64 %var_70 - %var_71 = load ptr, ptr %var_5 - store ptr %var_71, ptr %var_6 - %var_73 = load i64, ptr %var_2 - %var_7 = and i64 %var_73, 1 - %var_8 = icmp ne i64 %var_7, 0 - br i1 %var_8, label %block_4, label %block_6 + %var_104 = load i64, ptr %var_1 + %var_3 = getelementptr ptr, ptr @array0, i64 %var_104 + %var_105 = load ptr, ptr %var_3 + call void @__quantum__qis__h__body(ptr %var_105) + %var_5 = add i64 %var_104, 1 + store i64 %var_5, ptr %var_1 + br label %block_1 block_3: - %var_38 = load i64, ptr %var_2 - %var_11 = icmp eq i64 %var_38, 0 - store i64 0, ptr %var_12 - br label %block_5 + store i64 33, ptr %var_6 + store i64 0, ptr %var_7 + br label %block_4 block_4: - %var_78 = load ptr, ptr %var_6 - call void @__quantum__qis__x__body(ptr %var_78) - br label %block_6 + %var_52 = load i64, ptr %var_7 + %var_8 = icmp slt i64 %var_52, 6 + br i1 %var_8, label %block_5, label %block_6 block_5: - %var_40 = load i64, ptr %var_12 - %var_13 = icmp sle i64 %var_40, 2 - store i1 true, ptr %var_14 - br i1 %var_13, label %block_7, label %block_8 + %var_95 = load i64, ptr %var_7 + 
%var_9 = getelementptr ptr, ptr @array0, i64 %var_95 + %var_96 = load ptr, ptr %var_9 + store ptr %var_96, ptr %var_10 + %var_98 = load i64, ptr %var_6 + %var_11 = and i64 %var_98, 1 + %var_12 = icmp ne i64 %var_11, 0 + br i1 %var_12, label %block_7, label %block_9 block_6: - %var_74 = load i64, ptr %var_2 - %var_9 = ashr i64 %var_74, 1 - store i64 %var_9, ptr %var_2 - %var_76 = load i64, ptr %var_3 - %var_10 = add i64 %var_76, 1 - store i64 %var_10, ptr %var_3 - br label %block_1 + %var_53 = load i64, ptr %var_6 + %var_15 = icmp eq i64 %var_53, 0 + store i64 0, ptr %var_16 + br label %block_8 block_7: - %var_43 = load i1, ptr %var_14 - br i1 %var_43, label %block_9, label %block_10 + %var_103 = load ptr, ptr %var_10 + call void @__quantum__qis__x__body(ptr %var_103) + br label %block_9 block_8: - store i1 false, ptr %var_14 - br label %block_7 + %var_55 = load i64, ptr %var_16 + %var_17 = icmp sle i64 %var_55, 2 + store i1 true, ptr %var_18 + br i1 %var_17, label %block_10, label %block_11 block_9: - %var_66 = load i64, ptr %var_12 - %var_15 = getelementptr ptr, ptr @array1, i64 %var_66 - %var_67 = load ptr, ptr %var_15 - %var_16 = getelementptr ptr, ptr @array2, i64 %var_66 - %var_68 = load ptr, ptr %var_16 - call void @__quantum__qis__cz__body(ptr %var_67, ptr %var_68) - %var_17 = add i64 %var_66, 1 - store i64 %var_17, ptr %var_12 - br label %block_5 + %var_99 = load i64, ptr %var_6 + %var_13 = ashr i64 %var_99, 1 + store i64 %var_13, ptr %var_6 + %var_101 = load i64, ptr %var_7 + %var_14 = add i64 %var_101, 1 + store i64 %var_14, ptr %var_7 + br label %block_4 block_10: - store i64 33, ptr %var_18 - store i64 0, ptr %var_19 - br label %block_11 + %var_58 = load i1, ptr %var_18 + br i1 %var_58, label %block_12, label %block_13 block_11: - %var_46 = load i64, ptr %var_19 - %var_20 = icmp slt i64 %var_46, 6 - br i1 %var_20, label %block_12, label %block_13 + store i1 false, ptr %var_18 + br label %block_10 block_12: - %var_57 = load i64, ptr %var_19 - %var_21 = 
getelementptr ptr, ptr @array0, i64 %var_57 - %var_58 = load ptr, ptr %var_21 - store ptr %var_58, ptr %var_22 - %var_60 = load i64, ptr %var_18 - %var_23 = and i64 %var_60, 1 - %var_24 = icmp ne i64 %var_23, 0 - br i1 %var_24, label %block_14, label %block_16 + %var_91 = load i64, ptr %var_16 + %var_19 = getelementptr ptr, ptr @array1, i64 %var_91 + %var_92 = load ptr, ptr %var_19 + %var_20 = getelementptr ptr, ptr @array2, i64 %var_91 + %var_93 = load ptr, ptr %var_20 + call void @__quantum__qis__cz__body(ptr %var_92, ptr %var_93) + %var_21 = add i64 %var_91, 1 + store i64 %var_21, ptr %var_16 + br label %block_8 block_13: - %var_47 = load i64, ptr %var_18 - %var_27 = icmp eq i64 %var_47, 0 - call void @__quantum__qis__h__body(ptr inttoptr (i64 0 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 1 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 2 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 3 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 4 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 5 to ptr)) - store i64 0, ptr %var_29 - br label %block_15 + store i64 33, ptr %var_22 + store i64 0, ptr %var_23 + br label %block_14 block_14: - %var_65 = load ptr, ptr %var_22 - call void @__quantum__qis__x__body(ptr %var_65) - br label %block_16 + %var_61 = load i64, ptr %var_23 + %var_24 = icmp slt i64 %var_61, 6 + br i1 %var_24, label %block_15, label %block_16 block_15: - %var_49 = load i64, ptr %var_29 - %var_30 = icmp sle i64 %var_49, 2 - store i1 true, ptr %var_31 - br i1 %var_30, label %block_17, label %block_18 + %var_82 = load i64, ptr %var_23 + %var_25 = getelementptr ptr, ptr @array0, i64 %var_82 + %var_83 = load ptr, ptr %var_25 + store ptr %var_83, ptr %var_26 + %var_85 = load i64, ptr %var_22 + %var_27 = and i64 %var_85, 1 + %var_28 = icmp ne i64 %var_27, 0 + br i1 %var_28, label %block_17, label %block_19 block_16: - %var_61 = load i64, ptr %var_18 - %var_25 = ashr i64 %var_61, 1 
- store i64 %var_25, ptr %var_18 - %var_63 = load i64, ptr %var_19 - %var_26 = add i64 %var_63, 1 - store i64 %var_26, ptr %var_19 - br label %block_11 + %var_62 = load i64, ptr %var_22 + %var_31 = icmp eq i64 %var_62, 0 + store i64 0, ptr %var_32 + br label %block_18 block_17: - %var_52 = load i1, ptr %var_31 - br i1 %var_52, label %block_19, label %block_20 + %var_90 = load ptr, ptr %var_26 + call void @__quantum__qis__x__body(ptr %var_90) + br label %block_19 block_18: - store i1 false, ptr %var_31 - br label %block_17 + %var_64 = load i64, ptr %var_32 + %var_33 = icmp slt i64 %var_64, 6 + br i1 %var_33, label %block_20, label %block_21 block_19: - %var_53 = load i64, ptr %var_29 - %var_32 = getelementptr ptr, ptr @array1, i64 %var_53 - %var_54 = load ptr, ptr %var_32 - %var_33 = getelementptr ptr, ptr @array2, i64 %var_53 - %var_55 = load ptr, ptr %var_33 - call void @__quantum__qis__cz__body(ptr %var_54, ptr %var_55) - %var_34 = add i64 %var_53, 1 - store i64 %var_34, ptr %var_29 - br label %block_15 + %var_86 = load i64, ptr %var_22 + %var_29 = ashr i64 %var_86, 1 + store i64 %var_29, ptr %var_22 + %var_88 = load i64, ptr %var_23 + %var_30 = add i64 %var_88, 1 + store i64 %var_30, ptr %var_23 + br label %block_14 block_20: - call void @__quantum__qis__h__body(ptr inttoptr (i64 5 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 4 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 3 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 2 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 1 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 0 to ptr)) + %var_79 = load i64, ptr %var_32 + %var_34 = getelementptr ptr, ptr @array0, i64 %var_79 + %var_80 = load ptr, ptr %var_34 + call void @__quantum__qis__h__body(ptr %var_80) + %var_36 = add i64 %var_79, 1 + store i64 %var_36, ptr %var_32 + br label %block_18 +block_21: + store i64 0, ptr %var_37 + br label %block_22 +block_22: + %var_66 = load i64, ptr 
%var_37 + %var_38 = icmp sle i64 %var_66, 2 + store i1 true, ptr %var_39 + br i1 %var_38, label %block_23, label %block_24 +block_23: + %var_69 = load i1, ptr %var_39 + br i1 %var_69, label %block_25, label %block_26 +block_24: + store i1 false, ptr %var_39 + br label %block_23 +block_25: + %var_75 = load i64, ptr %var_37 + %var_40 = getelementptr ptr, ptr @array1, i64 %var_75 + %var_76 = load ptr, ptr %var_40 + %var_41 = getelementptr ptr, ptr @array2, i64 %var_75 + %var_77 = load ptr, ptr %var_41 + call void @__quantum__qis__cz__body(ptr %var_76, ptr %var_77) + %var_42 = add i64 %var_75, 1 + store i64 %var_42, ptr %var_37 + br label %block_22 +block_26: + store i64 5, ptr %var_43 + br label %block_27 +block_27: + %var_71 = load i64, ptr %var_43 + %var_44 = icmp sge i64 %var_71, 0 + br i1 %var_44, label %block_28, label %block_29 +block_28: + %var_72 = load i64, ptr %var_43 + %var_45 = getelementptr ptr, ptr @array0, i64 %var_72 + %var_73 = load ptr, ptr %var_45 + call void @__quantum__qis__h__body(ptr %var_73) + %var_47 = add i64 %var_72, -1 + store i64 %var_47, ptr %var_43 + br label %block_27 +block_29: call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 0 to ptr)) call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 1 to ptr), ptr inttoptr (i64 1 to ptr)) call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 2 to ptr), ptr inttoptr (i64 2 to ptr)) diff --git a/source/pip/tests-integration/resources/adaptive_rifla/output/RUS.ll b/source/pip/tests-integration/resources/adaptive_rifla/output/RUS.ll index 4ea4a01861..08fa96ebbd 100644 --- a/source/pip/tests-integration/resources/adaptive_rifla/output/RUS.ll +++ b/source/pip/tests-integration/resources/adaptive_rifla/output/RUS.ll @@ -6,24 +6,17 @@ define i64 @ENTRYPOINT__main() #0 { block_0: %var_1 = alloca i1 - %var_6 = alloca i64 + %var_2 = alloca i64 + %var_10 = alloca i64 call void @__quantum__rt__initialize(ptr null) store i1 true, ptr %var_1 br label %block_1 
block_1: - %var_12 = load i1, ptr %var_1 - br i1 %var_12, label %block_2, label %block_3 + %var_16 = load i1, ptr %var_1 + br i1 %var_16, label %block_2, label %block_3 block_2: - call void @__quantum__qis__h__body(ptr inttoptr (i64 0 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 1 to ptr)) - call void @__quantum__qis__ccx__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 1 to ptr), ptr inttoptr (i64 2 to ptr)) - call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 2 to ptr), ptr inttoptr (i64 0 to ptr)) - %var_3 = call i1 @__quantum__rt__read_result(ptr inttoptr (i64 0 to ptr)) - %var_4 = icmp eq i1 %var_3, false - %var_5 = xor i1 %var_4, true - store i1 %var_5, ptr %var_1 - %var_14 = load i1, ptr %var_1 - br i1 %var_14, label %block_4, label %block_5 + store i64 0, ptr %var_2 + br label %block_4 block_3: call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 1 to ptr)) call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 1 to ptr), ptr inttoptr (i64 2 to ptr)) @@ -32,24 +25,45 @@ block_3: call void @__quantum__rt__result_record_output(ptr inttoptr (i64 2 to ptr), ptr @2) ret i64 0 block_4: - store i64 0, ptr %var_6 - br label %block_6 + %var_18 = load i64, ptr %var_2 + %var_3 = icmp slt i64 %var_18, 2 + br i1 %var_3, label %block_5, label %block_6 block_5: - br label %block_1 + %var_26 = load i64, ptr %var_2 + %var_4 = getelementptr ptr, ptr @array0, i64 %var_26 + %var_27 = load ptr, ptr %var_4 + call void @__quantum__qis__h__body(ptr %var_27) + %var_6 = add i64 %var_26, 1 + store i64 %var_6, ptr %var_2 + br label %block_4 block_6: - %var_16 = load i64, ptr %var_6 - %var_7 = icmp slt i64 %var_16, 2 - br i1 %var_7, label %block_7, label %block_8 + call void @__quantum__qis__ccx__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 1 to ptr), ptr inttoptr (i64 2 to ptr)) + call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 2 to ptr), ptr inttoptr (i64 0 to ptr)) + %var_7 = call i1 
@__quantum__rt__read_result(ptr inttoptr (i64 0 to ptr)) + %var_8 = icmp eq i1 %var_7, false + %var_9 = xor i1 %var_8, true + store i1 %var_9, ptr %var_1 + %var_20 = load i1, ptr %var_1 + br i1 %var_20, label %block_7, label %block_8 block_7: - %var_17 = load i64, ptr %var_6 - %var_8 = getelementptr ptr, ptr @array0, i64 %var_17 - %var_18 = load ptr, ptr %var_8 - call void @__quantum__qis__reset__body(ptr %var_18) - %var_10 = add i64 %var_17, 1 - store i64 %var_10, ptr %var_6 - br label %block_6 + store i64 0, ptr %var_10 + br label %block_9 block_8: - br label %block_5 + br label %block_1 +block_9: + %var_22 = load i64, ptr %var_10 + %var_11 = icmp slt i64 %var_22, 2 + br i1 %var_11, label %block_10, label %block_11 +block_10: + %var_23 = load i64, ptr %var_10 + %var_12 = getelementptr ptr, ptr @array0, i64 %var_23 + %var_24 = load ptr, ptr %var_12 + call void @__quantum__qis__reset__body(ptr %var_24) + %var_14 = add i64 %var_23, 1 + store i64 %var_14, ptr %var_10 + br label %block_9 +block_11: + br label %block_8 } declare void @__quantum__rt__initialize(ptr) diff --git a/source/pip/tests/test_interpreter.py b/source/pip/tests/test_interpreter.py index d9c4e365ab..e95e46b536 100644 --- a/source/pip/tests/test_interpreter.py +++ b/source/pip/tests/test_interpreter.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
+import json from textwrap import dedent from qsharp._native import ( Interpreter, @@ -430,6 +431,52 @@ def test_qirgen() -> None: assert isinstance(qir, str) +def test_estimate_from_udt_returning_callable_matches_logical_counts_on_base_profile() -> ( + None +): + counted = None + + def make_callable(callable_value, _namespace, callable_name): + nonlocal counted + if callable_name == "Counted": + counted = callable_value + + e = Interpreter(TargetProfile.Base, make_callable=make_callable) + e.interpret( + dedent( + """ + struct Data { tally: Int } + + // The UDT output makes this a useful regression for callable + // estimation and counting on the live interpreter path. + operation Counted() : Data { + use q = Qubit(); + T(q); + MResetZ(q); + new Data { tally = 0 } + } + """ + ) + ) + + assert counted is not None + + estimate = json.loads(e.estimate("", callable=counted)) + logical_counts = e.logical_counts(callable=counted) + + assert estimate[0]["status"] == "success" + assert estimate[0]["logicalCounts"] == logical_counts + assert logical_counts == { + "numQubits": 1, + "tCount": 1, + "rotationCount": 0, + "rotationDepth": 0, + "cczCount": 0, + "ccixCount": 0, + "measurementCount": 1, + } + + def test_run_with_shots() -> None: e = Interpreter(TargetProfile.Unrestricted) @@ -560,7 +607,7 @@ def test_adaptive_errors_are_raised_from_entry_expr() -> None: assert "Qsc.CapabilitiesCk.UseOfDynamicDouble" in str(excinfo) -def test_adaptive_ri_qir_can_be_generated() -> None: +def test_adaptive_ri_entrypoint_generates_expected_qir() -> None: adaptive_input = """ namespace Test { import Std.Math.*; @@ -581,7 +628,9 @@ def test_adaptive_ri_qir_can_be_generated() -> None: e = Interpreter(TargetProfile.Adaptive_RI) e.interpret(adaptive_input) qir = e.qir("Test.Main()") - assert_expected_inline(qir, """\ + assert_expected_inline( + qir, + """\ %Result = type opaque %Qubit = type opaque @@ -618,10 +667,11 @@ def test_adaptive_ri_qir_can_be_generated() -> None: !2 = !{i32 1, 
!"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} !4 = !{i32 5, !"int_computations", !{!"i64"}} -""") +""", + ) -def test_base_qir_can_be_generated() -> None: +def test_base_profile_entrypoint_generates_expected_qir() -> None: base_input = """ namespace Test { import Std.Math.*; @@ -642,7 +692,9 @@ def test_base_qir_can_be_generated() -> None: e = Interpreter(TargetProfile.Base) e.interpret(base_input) qir = e.qir("Test.Main()") - assert_expected_inline(qir, """\ + assert_expected_inline( + qir, + """\ %Result = type opaque %Qubit = type opaque @@ -678,7 +730,8 @@ def test_base_qir_can_be_generated() -> None: !1 = !{i32 7, !"qir_minor_version", i32 0} !2 = !{i32 1, !"dynamic_qubit_management", i1 false} !3 = !{i32 1, !"dynamic_result_management", i1 false} -""") +""", + ) def test_operation_circuit() -> None: diff --git a/source/pip/tests/test_qasm.py b/source/pip/tests/test_qasm.py index 7ddb4d0c54..c8dc5fdc21 100644 --- a/source/pip/tests/test_qasm.py +++ b/source/pip/tests/test_qasm.py @@ -826,6 +826,35 @@ def test_qasm_estimation() -> None: ) +def test_qasm_estimate_succeeds_for_dynamic_bool_program_rejected_by_compile() -> None: + source = """ + OPENQASM 3.0; + include "stdgates.inc"; + qubit q; + bit c; + c = measure q; + if (c) { x q; } + """ + + with pytest.raises(QSharpError, match="Qsc.CapabilitiesCk.UseOfDynamicBool"): + compile(source) + + res = estimate(source) + + assert res["status"] == "success" + assert res["physicalCounts"] is not None + assert res.logical_counts == LogicalCounts( + { + "numQubits": 1, + "tCount": 0, + "rotationCount": 0, + "rotationDepth": 0, + "cczCount": 0, + "measurementCount": 1, + } + ) + + def test_qasm_estimation_with_single_params() -> None: params = EstimatorParams() params.error_budget = 0.333 diff --git a/source/resource_estimator/src/counts/tests.rs b/source/resource_estimator/src/counts/tests.rs index f3fdc2a67f..cdbd3155f2 100644 --- 
a/source/resource_estimator/src/counts/tests.rs +++ b/source/resource_estimator/src/counts/tests.rs @@ -8,10 +8,12 @@ use indoc::indoc; use miette::Report; use qsc::{ LanguageFeatures, PackageType, SourceMap, TargetCapabilityFlags, - interpret::{GenericReceiver, Interpreter}, + interpret::{GenericReceiver, Interpreter, Value}, target::Profile, }; +use crate::logical_counts_call; + use super::LogicalCounter; fn verify_logical_counts(source: &str, entry: Option<&str>, expect: &Expect) { @@ -53,6 +55,14 @@ fn verify_logical_counts(source: &str, entry: Option<&str>, expect: &Expect) { } } +fn source_global(interpreter: &Interpreter, name: &str) -> Value { + interpreter + .source_globals() + .into_iter() + .find_map(|(_, global_name, value)| (global_name.as_ref() == name).then_some(value)) + .unwrap_or_else(|| panic!("{name} should be present in source globals")) +} + #[test] fn gates_are_counted() { verify_logical_counts( @@ -240,6 +250,59 @@ fn account_for_estimates_works() { ); } +#[test] +fn logical_counts_call_counts_callable_with_udt_output() { + // The callable returns a UDT so stricter backend-preparation paths would + // impose output-shape constraints here. logical_counts_call should still + // count gates by invoking the live interpreter directly. + let source = indoc! 
{r#" + namespace Test { + struct Data { + tally : Int + } + + operation Counted() : Data { + use q = Qubit(); + T(q); + MResetZ(q); + new Data { tally = 0 } + } + } + "#}; + let source_map = SourceMap::new([("test".into(), source.into())], None); + let (std_id, store) = qsc::compile::package_store_with_stdlib(Profile::Base.into()); + + let mut interpreter = Interpreter::new( + source_map, + PackageType::Lib, + Profile::Base.into(), + LanguageFeatures::default(), + store, + &[(std_id, None)], + ) + .expect("compilation should succeed"); + + let callable = source_global(&interpreter, "Counted"); + let counts = logical_counts_call(&mut interpreter, callable, Value::unit()) + .expect("logical counting should stay on the live interpreter path"); + + expect![[r#" + LogicalResourceCounts { + num_qubits: 1, + t_count: 1, + rotation_count: 0, + rotation_depth: 0, + ccz_count: 0, + ccix_count: 0, + measurement_count: 1, + num_compute_qubits: None, + read_from_memory_count: None, + write_to_memory_count: None, + } + "#]] + .assert_debug_eq(&counts); +} + #[test] fn pauli_i_rotation_for_global_phase_is_noop() { verify_logical_counts( diff --git a/source/samples_test/src/tests.rs b/source/samples_test/src/tests.rs index a35cc64da6..c7c3c5674c 100644 --- a/source/samples_test/src/tests.rs +++ b/source/samples_test/src/tests.rs @@ -31,6 +31,7 @@ use qsc::{ compiler::parse_and_compile_to_qsharp_ast_with_config, io::InMemorySourceResolver, }, packages::BuildableProgram, + target::Profile, }; use qsc_project::{FileSystem, ProjectType, StdFs}; @@ -124,6 +125,7 @@ fn compile_and_run_qasm_internal(source: &str, debug: bool) -> String { config, ); let (source_map, errors, package, sig, profile) = unit.into_tuple(); + let profile = profile.unwrap_or(Profile::Unrestricted); assert!(errors.is_empty(), "QASM compilation failed: {errors:?}"); let Some(signature) = sig else { diff --git a/source/samples_test/src/tests/algorithms.rs b/source/samples_test/src/tests/algorithms.rs index 
3a43ba5332..9a80b31175 100644 --- a/source/samples_test/src/tests/algorithms.rs +++ b/source/samples_test/src/tests/algorithms.rs @@ -8,8 +8,8 @@ use expect_test::{Expect, expect}; // fail to compile until the new expect strings are added. pub const BERNSTEINVAZIRANI_EXPECT: Expect = expect!["[127, 238, 512]"]; pub const BERNSTEINVAZIRANI_EXPECT_DEBUG: Expect = expect!["[127, 238, 512]"]; -pub const BERNSTEINVAZIRANI_EXPECT_CIRCUIT: Expect = expect!["generated circuit of length 29822"]; -pub const BERNSTEINVAZIRANI_EXPECT_QIR: Expect = expect!["generated QIR of length 20273"]; +pub const BERNSTEINVAZIRANI_EXPECT_CIRCUIT: Expect = expect!["generated circuit of length 27618"]; +pub const BERNSTEINVAZIRANI_EXPECT_QIR: Expect = expect!["generated QIR of length 19369"]; pub const BERNSTEINVAZIRANINISQ_EXPECT: Expect = expect!["[One, Zero, One, Zero, One]"]; pub const BERNSTEINVAZIRANINISQ_EXPECT_DEBUG: Expect = expect!["[One, Zero, One, Zero, One]"]; pub const BERNSTEINVAZIRANINISQ_EXPECT_CIRCUIT: Expect = @@ -186,9 +186,10 @@ pub const SIMPLEVQE_EXPECT_DEBUG: Expect = expect![[r#" Descent done. Attempts: 52, Step: 0.0009765625, Arguments: [1.5, 1.0625], Value: 0.3216. 0.3216"#]]; // VQE sample is not expected to produce a circuit as it is too large and complex. 
-pub const SIMPLEVQE_EXPECT_CIRCUIT: Expect = expect!["circuit error: partial evaluation error"]; +pub const SIMPLEVQE_EXPECT_CIRCUIT: Expect = + expect!["circuit error: cannot use a dynamically-sized array"]; pub const SIMPLEVQE_EXPECT_QIR: Expect = - expect!["QIR generation error for `SimpleVQE.Main()`: partial evaluation error"]; + expect!["QIR generation error for `SimpleVQE.Main()`: cannot use a dynamically-sized array"]; pub const SUPERDENSECODING_EXPECT: Expect = expect!["((false, true), (false, true))"]; pub const SUPERDENSECODING_EXPECT_DEBUG: Expect = expect!["((false, true), (false, true))"]; pub const SUPERDENSECODING_EXPECT_CIRCUIT: Expect = expect!["generated circuit of length 4891"]; diff --git a/source/wasm/src/diagnostic.rs b/source/wasm/src/diagnostic.rs index 69126aad59..8f9a5c9696 100644 --- a/source/wasm/src/diagnostic.rs +++ b/source/wasm/src/diagnostic.rs @@ -283,6 +283,7 @@ fn interpret_error_labels(err: &interpret::Error) -> Vec