From 5511dcd210aa14ee2aacbfd1e414e7ec3b1cb1ff Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Mon, 16 Oct 2023 11:51:18 +0300 Subject: [PATCH 01/22] Add `ConstKind::Undef` (i.e. SPIR-V `OpUndef`). --- README.md | 2 +- src/cfg.rs | 10 +------ src/lib.rs | 6 ++++ src/print/mod.rs | 8 ++++- src/spv/canonical.rs | 70 ++++++++++++++++++++++++++++++++++++++++++++ src/spv/lift.rs | 24 +++++++++++---- src/spv/lower.rs | 32 +++++++++++++++----- src/spv/mod.rs | 1 + src/spv/spec.rs | 1 - src/transform.rs | 4 ++- src/visit.rs | 3 +- 11 files changed, 133 insertions(+), 28 deletions(-) create mode 100644 src/spv/canonical.rs diff --git a/README.md b/README.md index 5b68e11a..515d6ed6 100644 --- a/README.md +++ b/README.md @@ -145,7 +145,7 @@ func F0() -> spv.OpTypeVoid { v6 = spv.OpIAdd(v1, 1s32): s32 (v5, v6) } else { - (spv.OpUndef: s32, spv.OpUndef: s32) + (undef: s32, undef: s32) } (v3, v4) -> (v0, v1) } while v2 diff --git a/src/cfg.rs b/src/cfg.rs index de8a6298..46037907 100644 --- a/src/cfg.rs +++ b/src/cfg.rs @@ -1568,14 +1568,6 @@ impl<'a> Structurizer<'a> { /// Create an undefined constant (as a placeholder where a value needs to be /// present, but won't actually be used), of type `ty`. fn const_undef(&self, ty: Type) -> Const { - // FIXME(eddyb) SPIR-T should have native undef itself. - let wk = &spv::spec::Spec::get().well_known; - self.cx.intern(ConstDef { - attrs: AttrSet::default(), - ty, - kind: ConstKind::SpvInst { - spv_inst_and_const_inputs: Rc::new((wk.OpUndef.into(), [].into_iter().collect())), - }, - }) + self.cx.intern(ConstDef { attrs: AttrSet::default(), ty, kind: ConstKind::Undef }) } } diff --git a/src/lib.rs b/src/lib.rs index 7723c951..d2823659 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -520,6 +520,12 @@ pub struct ConstDef { #[derive(Clone, PartialEq, Eq, Hash)] pub enum ConstKind { + /// Undeterminate value (i.e. SPIR-V `OpUndef`, LLVM `undef`). 
+ // + // FIXME(eddyb) could it be possible to adopt LLVM's newer `poison`+`freeze` + // model, without being forced to never lift back to `OpUndef`? + Undef, + PtrToGlobalVar(GlobalVar), // HACK(eddyb) this is a fallback case that should become increasingly rare diff --git a/src/print/mod.rs b/src/print/mod.rs index c37eca03..e97f8e65 100644 --- a/src/print/mod.rs +++ b/src/print/mod.rs @@ -2581,9 +2581,14 @@ impl Print for ConstDef { AttrsAndDef { attrs: attrs.print(printer), def_without_name: compact_def.unwrap_or_else(|| match kind { + ConstKind::Undef => pretty::Fragment::new([ + printer.imperative_keyword_style().apply("undef").into(), + printer.pretty_type_ascription_suffix(*ty), + ]), &ConstKind::PtrToGlobalVar(gv) => { pretty::Fragment::new(["&".into(), gv.print(printer)]) } + ConstKind::SpvInst { spv_inst_and_const_inputs } => { let (spv_inst, const_inputs) = &**spv_inst_and_const_inputs; pretty::Fragment::new([ @@ -3294,6 +3299,8 @@ impl Print for FuncAt<'_, DataInst> { let pseudo_imm_from_value = |v: Value| { if let Value::Const(ct) = v { match &printer.cx[ct].kind { + ConstKind::Undef | ConstKind::PtrToGlobalVar(_) => {} + &ConstKind::SpvStringLiteralForExtInst(s) => { return Some(PseudoImm::Str(&printer.cx[s])); } @@ -3308,7 +3315,6 @@ impl Print for FuncAt<'_, DataInst> { } } } - ConstKind::PtrToGlobalVar(_) => {} } } None diff --git a/src/spv/canonical.rs b/src/spv/canonical.rs new file mode 100644 index 00000000..158e879d --- /dev/null +++ b/src/spv/canonical.rs @@ -0,0 +1,70 @@ +//! Bidirectional (SPIR-V <-> SPIR-T) "canonical mappings". +//! +//! Both directions are defined close together as much as possible, to: +//! - limit code duplication, making it easy to add more mappings +//! - limit how much they could even go out of sync over time +//! - prevent naming e.g. SPIR-V opcodes, outside canonicalization +// +// FIXME(eddyb) should interning attempts check/apply these canonicalizations? 
+ +use crate::spv::{self, spec}; +use crate::ConstKind; +use lazy_static::lazy_static; + +// FIXME(eddyb) these ones could maybe make use of build script generation. +macro_rules! def_mappable_ops { + ($($op:ident),+ $(,)?) => { + #[allow(non_snake_case)] + struct MappableOps { + $($op: spec::Opcode,)+ + } + impl MappableOps { + #[inline(always)] + #[must_use] + pub fn get() -> &'static MappableOps { + lazy_static! { + static ref MAPPABLE_OPS: MappableOps = { + let spv_spec = spec::Spec::get(); + MappableOps { + $($op: spv_spec.instructions.lookup(stringify!($op)).unwrap(),)+ + } + }; + } + &MAPPABLE_OPS + } + } + }; +} +def_mappable_ops! { + OpUndef, +} + +// FIXME(eddyb) decide on a visibility scope - `pub(super)` avoids some mistakes +// (using these methods outside of `spv::{lower,lift}`), but may be too restrictive. +impl spv::Inst { + pub(super) fn as_canonical_const(&self) -> Option { + let Self { opcode, imms } = self; + let (&opcode, imms) = (opcode, &imms[..]); + + let mo = MappableOps::get(); + + if opcode == mo.OpUndef { + assert_eq!(imms.len(), 0); + Some(ConstKind::Undef) + } else { + None + } + } + + pub(super) fn from_canonical_const(const_kind: &ConstKind) -> Option { + let mo = MappableOps::get(); + + match const_kind { + ConstKind::Undef => Some(mo.OpUndef.into()), + + ConstKind::PtrToGlobalVar(_) + | ConstKind::SpvInst { .. } + | ConstKind::SpvStringLiteralForExtInst(_) => None, + } + } +} diff --git a/src/spv/lift.rs b/src/spv/lift.rs index 690dd449..a74db0f2 100644 --- a/src/spv/lift.rs +++ b/src/spv/lift.rs @@ -146,7 +146,7 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { } let ct_def = &self.cx[ct]; match ct_def.kind { - ConstKind::PtrToGlobalVar(_) | ConstKind::SpvInst { .. } => { + ConstKind::Undef | ConstKind::PtrToGlobalVar(_) | ConstKind::SpvInst { .. } => { self.visit_const_def(ct_def); self.globals.insert(global); } @@ -1036,7 +1036,8 @@ impl LazyInst<'_, '_> { }; (gv_decl.attrs, import) } - ConstKind::SpvInst { .. 
} => (ct_def.attrs, None), + + ConstKind::Undef | ConstKind::SpvInst { .. } => (ct_def.attrs, None), // Not inserted into `globals` while visiting. ConstKind::SpvStringLiteralForExtInst(_) => unreachable!(), @@ -1123,8 +1124,19 @@ impl LazyInst<'_, '_> { }, Global::Const(ct) => { let ct_def = &cx[ct]; - match &ct_def.kind { - &ConstKind::PtrToGlobalVar(gv) => { + match spv::Inst::from_canonical_const(&ct_def.kind).ok_or(&ct_def.kind) { + Ok(spv_inst) => spv::InstWithIds { + without_ids: spv_inst, + result_type_id: Some(ids.globals[&Global::Type(ct_def.ty)]), + result_id, + ids: [].into_iter().collect(), + }, + + Err(ConstKind::Undef) => { + unreachable!("should've been handled as canonical") + } + + Err(&ConstKind::PtrToGlobalVar(gv)) => { assert!(ct_def.attrs == AttrSet::default()); let gv_decl = &module.global_vars[gv]; @@ -1157,7 +1169,7 @@ impl LazyInst<'_, '_> { } } - ConstKind::SpvInst { spv_inst_and_const_inputs } => { + Err(ConstKind::SpvInst { spv_inst_and_const_inputs }) => { let (spv_inst, const_inputs) = &**spv_inst_and_const_inputs; spv::InstWithIds { without_ids: spv_inst.clone(), @@ -1171,7 +1183,7 @@ impl LazyInst<'_, '_> { } // Not inserted into `globals` while visiting. - ConstKind::SpvStringLiteralForExtInst(_) => unreachable!(), + Err(ConstKind::SpvStringLiteralForExtInst(_)) => unreachable!(), } } }, diff --git a/src/spv/lower.rs b/src/spv/lower.rs index 1e62dc9e..6b26ac1a 100644 --- a/src/spv/lower.rs +++ b/src/spv/lower.rs @@ -580,7 +580,29 @@ impl Module { id_defs.insert(id, IdDef::Type(ty)); Seq::TypeConstOrGlobalVar - } else if inst_category == spec::InstructionCategory::Const || opcode == wk.OpUndef { + } else if let Some(const_kind) = inst.as_canonical_const() { + let id = inst.result_id.unwrap(); + assert_eq!(inst.ids.len(), 0); + + // FIXME(eddyb) this is used below for sequencing, so maybe it + // may be useful to still have some access here to `wk.OpUndef`. 
+ let is_op_undef = matches!(const_kind, ConstKind::Undef); + + let ct = cx.intern(ConstDef { + attrs: mem::take(&mut attrs), + ty: result_type.unwrap(), + kind: const_kind, + }); + id_defs.insert(id, IdDef::Const(ct)); + + if is_op_undef { + // `OpUndef` can appear either among constants, or in a + // function, so at most advance `seq` to globals. + seq.max(Some(Seq::TypeConstOrGlobalVar)).unwrap() + } else { + Seq::TypeConstOrGlobalVar + } + } else if inst_category == spec::InstructionCategory::Const { let id = inst.result_id.unwrap(); let const_inputs = inst .ids @@ -606,13 +628,7 @@ impl Module { }); id_defs.insert(id, IdDef::Const(ct)); - if opcode == wk.OpUndef { - // `OpUndef` can appear either among constants, or in a - // function, so at most advance `seq` to globals. - seq.max(Some(Seq::TypeConstOrGlobalVar)).unwrap() - } else { - Seq::TypeConstOrGlobalVar - } + Seq::TypeConstOrGlobalVar } else if opcode == wk.OpVariable && current_func_body.is_none() { let global_var_id = inst.result_id.unwrap(); let type_of_ptr_to_global_var = result_type.unwrap(); diff --git a/src/spv/mod.rs b/src/spv/mod.rs index eb5a2e7d..09728c1a 100644 --- a/src/spv/mod.rs +++ b/src/spv/mod.rs @@ -2,6 +2,7 @@ // NOTE(eddyb) all the modules are declared here, but they're documented "inside" // (i.e. using inner doc comments). +pub mod canonical; pub mod lift; pub mod lower; pub mod print; diff --git a/src/spv/spec.rs b/src/spv/spec.rs index 2b00cb24..06b8f551 100644 --- a/src/spv/spec.rs +++ b/src/spv/spec.rs @@ -136,7 +136,6 @@ def_well_known! 
{ OpConstantFalse, OpConstantTrue, OpConstant, - OpUndef, OpVariable, diff --git a/src/transform.rs b/src/transform.rs index 6b97b8aa..add05b20 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -457,6 +457,9 @@ impl InnerTransform for ConstDef { attrs -> transformer.transform_attr_set_use(*attrs), ty -> transformer.transform_type_use(*ty), kind -> match kind { + ConstKind::Undef + | ConstKind::SpvStringLiteralForExtInst(_) => Transformed::Unchanged, + ConstKind::PtrToGlobalVar(gv) => transform!({ gv -> transformer.transform_global_var_use(*gv), } => ConstKind::PtrToGlobalVar(gv)), @@ -470,7 +473,6 @@ impl InnerTransform for ConstDef { spv_inst_and_const_inputs: Rc::new((spv_inst.clone(), new_iter.collect())), }) } - ConstKind::SpvStringLiteralForExtInst(_) => Transformed::Unchanged }, } => Self { attrs, diff --git a/src/visit.rs b/src/visit.rs index 7bb837d5..a1ec4a73 100644 --- a/src/visit.rs +++ b/src/visit.rs @@ -336,6 +336,8 @@ impl InnerVisit for ConstDef { visitor.visit_attr_set_use(*attrs); visitor.visit_type_use(*ty); match kind { + ConstKind::Undef | ConstKind::SpvStringLiteralForExtInst(_) => {} + &ConstKind::PtrToGlobalVar(gv) => visitor.visit_global_var_use(gv), ConstKind::SpvInst { spv_inst_and_const_inputs } => { let (_spv_inst, const_inputs) = &**spv_inst_and_const_inputs; @@ -343,7 +345,6 @@ impl InnerVisit for ConstDef { visitor.visit_const_use(ct); } } - ConstKind::SpvStringLiteralForExtInst(_) => {} } } } From 0a6488001e1b33f19c1ab070450787cce47e8c85 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Mon, 16 Oct 2023 11:51:21 +0300 Subject: [PATCH 02/22] Add `TypeKind::Scalar`&`ConstKind::Scalar` for bool/int/float types&consts. 
--- src/cfg.rs | 35 +----- src/lib.rs | 80 ++++++++++-- src/print/mod.rs | 294 +++++++++++++++++-------------------------- src/qptr/layout.rs | 80 ++++++------ src/qptr/lift.rs | 68 ++-------- src/qptr/lower.rs | 14 +-- src/scalar.rs | 198 +++++++++++++++++++++++++++++ src/spv/canonical.rs | 170 +++++++++++++++++++++++-- src/spv/lift.rs | 84 ++++++++----- src/spv/lower.rs | 60 ++++----- src/spv/read.rs | 33 ++--- src/spv/spec.rs | 7 -- src/transform.rs | 5 +- src/visit.rs | 4 +- 14 files changed, 702 insertions(+), 430 deletions(-) create mode 100644 src/scalar.rs diff --git a/src/cfg.rs b/src/cfg.rs index 46037907..6a09cff8 100644 --- a/src/cfg.rs +++ b/src/cfg.rs @@ -1,15 +1,13 @@ //! Control-flow graph (CFG) abstractions and utilities. use crate::{ - spv, AttrSet, Const, ConstDef, ConstKind, Context, ControlNode, ControlNodeDef, + scalar, spv, AttrSet, Const, ConstDef, ConstKind, Context, ControlNode, ControlNodeDef, ControlNodeKind, ControlNodeOutputDecl, ControlRegion, ControlRegionDef, - EntityOrientedDenseMap, FuncDefBody, FxIndexMap, FxIndexSet, SelectionKind, Type, TypeKind, - Value, + EntityOrientedDenseMap, FuncDefBody, FxIndexMap, FxIndexSet, SelectionKind, Type, Value, }; use itertools::{Either, Itertools}; use smallvec::SmallVec; use std::mem; -use std::rc::Rc; /// The control-flow graph (CFG) of a function, as control-flow instructions /// ([`ControlInst`]s) attached to [`ControlRegion`]s, as an "action on exit", i.e. @@ -593,32 +591,9 @@ struct PartialControlRegion { impl<'a> Structurizer<'a> { pub fn new(cx: &'a Context, func_def_body: &'a mut FuncDefBody) -> Self { - // FIXME(eddyb) SPIR-T should have native booleans itself. 
- let wk = &spv::spec::Spec::get().well_known; - let type_bool = cx.intern(TypeKind::SpvInst { - spv_inst: wk.OpTypeBool.into(), - type_and_const_inputs: [].into_iter().collect(), - }); - let const_true = cx.intern(ConstDef { - attrs: AttrSet::default(), - ty: type_bool, - kind: ConstKind::SpvInst { - spv_inst_and_const_inputs: Rc::new(( - wk.OpConstantTrue.into(), - [].into_iter().collect(), - )), - }, - }); - let const_false = cx.intern(ConstDef { - attrs: AttrSet::default(), - ty: type_bool, - kind: ConstKind::SpvInst { - spv_inst_and_const_inputs: Rc::new(( - wk.OpConstantFalse.into(), - [].into_iter().collect(), - )), - }, - }); + let type_bool = cx.intern(scalar::Type::Bool); + let const_true = cx.intern(scalar::Const::TRUE); + let const_false = cx.intern(scalar::Const::FALSE); let (loop_header_to_exit_targets, incoming_edge_counts_including_loop_exits) = func_def_body diff --git a/src/lib.rs b/src/lib.rs index d2823659..c049394e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -168,6 +168,7 @@ pub mod passes { pub mod qptr; } pub mod qptr; +pub mod scalar; pub mod spv; use smallvec::SmallVec; @@ -453,16 +454,23 @@ impl Ord for OrdAssertEq { pub use context::Type; /// Definition for a [`Type`]. -// -// FIXME(eddyb) maybe special-case some basic types like integers. #[derive(PartialEq, Eq, Hash)] pub struct TypeDef { pub attrs: AttrSet, pub kind: TypeKind, } -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, derive_more::From)] pub enum TypeKind { + /// Scalar (`bool`, integer, and floating-point) type, with limitations + /// on the supported bit-widths (power-of-two multiples of a byte). + /// + /// **Note**: pointers are never scalars (like SPIR-V, but unlike other IRs). + /// + /// See also the [`scalar`] module for more documentation and definitions. 
+ #[from] + Scalar(scalar::Type), + /// "Quasi-pointer", an untyped pointer-like abstract scalar that can represent /// both memory locations (in any address space) and other kinds of locations /// (e.g. SPIR-V `OpVariable`s in non-memory "storage classes"). @@ -490,12 +498,18 @@ pub enum TypeKind { SpvStringLiteralForExtInst, } -// HACK(eddyb) this behaves like an implicit conversion for `cx.intern(...)`. -impl context::InternInCx for TypeKind { - fn intern_in_cx(self, cx: &Context) -> Type { - cx.intern(TypeDef { attrs: Default::default(), kind: self }) +// HACK(eddyb) this behaves like an implicit conversion for `cx.intern(...)`, +// and the macro is only used because coherence bans `impl>`. +macro_rules! impl_intern_type_kind { + ($($kind:ty),+ $(,)?) => { + $(impl context::InternInCx for $kind { + fn intern_in_cx(self, cx: &Context) -> Type { + cx.intern(TypeDef { attrs: Default::default(), kind: self.into() }) + } + })+ } } +impl_intern_type_kind!(TypeKind, scalar::Type); // HACK(eddyb) this is like `Either`, only used in `TypeKind::SpvInst`, // and only because SPIR-V type definitions can references both types and consts. @@ -505,6 +519,16 @@ pub enum TypeOrConst { Const(Const), } +// HACK(eddyb) on `Type` instead of `TypeDef` for ergonomics reasons. +impl Type { + pub fn as_scalar(self, cx: &Context) -> Option { + match cx[self].kind { + TypeKind::Scalar(ty) => Some(ty), + _ => None, + } + } +} + /// Interned handle for a [`ConstDef`](crate::ConstDef) (a constant value). pub use context::Const; @@ -518,7 +542,7 @@ pub struct ConstDef { pub kind: ConstKind, } -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, derive_more::From)] pub enum ConstKind { /// Undeterminate value (i.e. SPIR-V `OpUndef`, LLVM `undef`). // @@ -526,6 +550,18 @@ pub enum ConstKind { // model, without being forced to never lift back to `OpUndef`? 
Undef, + /// Scalar (`bool`, integer, and floating-point) constant, which must have + /// a type of [`TypeKind::Scalar`] (of the same [`scalar::Type`]). + /// + /// See also the [`scalar`] module for more documentation and definitions. + // + // FIXME(eddyb) maybe document the 128-bit limitation? + // FIXME(eddyb) this technically makes the `scalar::Type` redundant, could + // it get out of sync? (perhaps "forced canonicalization" could be used to + // enforce that interning simply doesn't allow such scenarios?). + #[from] + Scalar(scalar::Const), + PtrToGlobalVar(GlobalVar), // HACK(eddyb) this is a fallback case that should become increasingly rare @@ -540,6 +576,34 @@ pub enum ConstKind { SpvStringLiteralForExtInst(InternedStr), } +// HACK(eddyb) this behaves like an implicit conversion for `cx.intern(...)`, +// like the `TypeKind` one, but this one is even weirder because it also interns +// the inherent type of the constant, as a `Type` (with empty attributes). +macro_rules! impl_intern_const_kind { + ($($kind:ty),+ $(,)?) => { + $(impl context::InternInCx for $kind { + fn intern_in_cx(self, cx: &Context) -> Const { + cx.intern(ConstDef { + attrs: Default::default(), + ty: cx.intern(self.ty()), + kind: self.into(), + }) + } + })+ + } +} +impl_intern_const_kind!(scalar::Const); + +// HACK(eddyb) on `Const` instead of `ConstDef` for ergonomics reasons. +impl Const { + pub fn as_scalar(self, cx: &Context) -> Option<&scalar::Const> { + match &cx[self].kind { + ConstKind::Scalar(ct) => Some(ct), + _ => None, + } + } +} + /// Declarations ([`GlobalVarDecl`], [`FuncDecl`]) can contain a full definition, /// or only be an import of a definition (e.g. from another module). 
#[derive(Clone)] diff --git a/src/print/mod.rs b/src/print/mod.rs index e97f8e65..91637242 100644 --- a/src/print/mod.rs +++ b/src/print/mod.rs @@ -24,7 +24,7 @@ use crate::print::multiversion::Versions; use crate::qptr::{self, QPtrAttr, QPtrMemUsage, QPtrMemUsageKind, QPtrOp, QPtrUsage}; use crate::visit::{InnerVisit, Visit, Visitor}; use crate::{ - cfg, spv, AddrSpace, Attr, AttrSet, AttrSetDef, Const, ConstDef, ConstKind, Context, + cfg, scalar, spv, AddrSpace, Attr, AttrSet, AttrSetDef, Const, ConstDef, ConstKind, Context, ControlNode, ControlNodeDef, ControlNodeKind, ControlNodeOutputDecl, ControlRegion, ControlRegionDef, ControlRegionInputDecl, DataInst, DataInstDef, DataInstForm, DataInstFormDef, DataInstKind, DeclDef, Diag, DiagLevel, DiagMsgPart, EntityListIter, ExportKey, Exportee, Func, @@ -817,19 +817,13 @@ impl<'a> Printer<'a> { // here and `TypeDef`'s `Print` impl. let has_compact_print_or_is_leaf = match &ty_def.kind { TypeKind::SpvInst { spv_inst, type_and_const_inputs } => { - [ - wk.OpTypeBool, - wk.OpTypeInt, - wk.OpTypeFloat, - wk.OpTypeVector, - ] - .contains(&spv_inst.opcode) + spv_inst.opcode == wk.OpTypeVector || type_and_const_inputs.is_empty() } - TypeKind::QPtr | TypeKind::SpvStringLiteralForExtInst => { - true - } + TypeKind::Scalar(_) + | TypeKind::QPtr + | TypeKind::SpvStringLiteralForExtInst => true, }; ty_def.attrs == AttrSet::default() @@ -838,28 +832,16 @@ impl<'a> Printer<'a> { CxInterned::Const(ct) => { let ct_def = &cx[ct]; - // FIXME(eddyb) remove the duplication between - // here and `ConstDef`'s `Print` impl. 
- let (has_compact_print, has_nested_consts) = match &ct_def.kind - { + let has_nested_consts = match &ct_def.kind { ConstKind::SpvInst { spv_inst_and_const_inputs } => { - let (spv_inst, const_inputs) = + let (_spv_inst, const_inputs) = &**spv_inst_and_const_inputs; - ( - [ - wk.OpConstantFalse, - wk.OpConstantTrue, - wk.OpConstant, - ] - .contains(&spv_inst.opcode), - !const_inputs.is_empty(), - ) + !const_inputs.is_empty() } - _ => (false, false), + _ => false, }; - ct_def.attrs == AttrSet::default() - && (has_compact_print || !has_nested_consts) + ct_def.attrs == AttrSet::default() && !has_nested_consts } } } @@ -2380,30 +2362,13 @@ impl Print for TypeDef { // FIXME(eddyb) should this be done by lowering SPIR-V types to SPIR-T? let kw = |kw| printer.declarative_keyword_style().apply(kw).into(); + #[allow(irrefutable_let_patterns)] let compact_def = if let &TypeKind::SpvInst { spv_inst: spv::Inst { opcode, ref imms }, ref type_and_const_inputs, } = kind { - if opcode == wk.OpTypeBool { - Some(kw("bool".into())) - } else if opcode == wk.OpTypeInt { - let (width, signed) = match imms[..] { - [spv::Imm::Short(_, width), spv::Imm::Short(_, signedness)] => { - (width, signedness != 0) - } - _ => unreachable!(), - }; - - Some(if signed { kw(format!("s{width}")) } else { kw(format!("u{width}")) }) - } else if opcode == wk.OpTypeFloat { - let width = match imms[..] 
{ - [spv::Imm::Short(_, width)] => width, - _ => unreachable!(), - }; - - Some(kw(format!("f{width}"))) - } else if opcode == wk.OpTypeVector { + if opcode == wk.OpTypeVector { let (elem_ty, elem_count) = match (&imms[..], &type_and_const_inputs[..]) { (&[spv::Imm::Short(_, elem_count)], &[TypeOrConst::Type(elem_ty)]) => { (elem_ty, elem_count) @@ -2429,6 +2394,16 @@ impl Print for TypeDef { def } else { match kind { + TypeKind::Scalar(ty) => { + let width = ty.bit_width(); + kw(match ty { + scalar::Type::Bool => "bool".into(), + scalar::Type::SInt(_) => format!("s{width}"), + scalar::Type::UInt(_) => format!("u{width}"), + scalar::Type::Float(_) => format!("f{width}"), + }) + } + // FIXME(eddyb) should this be shortened to `qtr`? TypeKind::QPtr => printer.declarative_keyword_style().apply("qptr").into(), @@ -2462,71 +2437,22 @@ impl Print for ConstDef { let wk = &spv::spec::Spec::get().well_known; let kw = |kw| printer.declarative_keyword_style().apply(kw).into(); - let literal_ty_suffix = |ty| { - pretty::Styles { - // HACK(eddyb) the exact type detracts from the value. - color_opacity: Some(0.4), - subscript: true, - ..printer.declarative_keyword_style() - } - .apply(ty) - }; - let compact_def = if let ConstKind::SpvInst { spv_inst_and_const_inputs } = kind { - let (spv_inst, _const_inputs) = &**spv_inst_and_const_inputs; - let &spv::Inst { opcode, ref imms } = spv_inst; - - if opcode == wk.OpConstantFalse { - Some(kw("false")) - } else if opcode == wk.OpConstantTrue { - Some(kw("true")) - } else if opcode == wk.OpConstant { - // HACK(eddyb) it's simpler to only handle a limited subset of - // integer/float bit-widths, for now. - let raw_bits = match imms[..] { - [spv::Imm::Short(_, x)] => Some(u64::from(x)), - [spv::Imm::LongStart(_, lo), spv::Imm::LongCont(_, hi)] => { - Some(u64::from(lo) | (u64::from(hi) << 32)) - } - _ => None, - }; - - if let ( - Some(raw_bits), - &TypeKind::SpvInst { - spv_inst: spv::Inst { opcode: ty_opcode, imms: ref ty_imms }, - .. 
- }, - ) = (raw_bits, &printer.cx[*ty].kind) - { - if ty_opcode == wk.OpTypeInt { - let (width, signed) = match ty_imms[..] { - [spv::Imm::Short(_, width), spv::Imm::Short(_, signedness)] => { - (width, signedness != 0) - } - _ => unreachable!(), - }; - - if width <= 64 { - let (printed_value, ty) = if signed { - let sext_raw_bits = - (raw_bits as u128 as i128) << (128 - width) >> (128 - width); - (format!("{sext_raw_bits}"), format!("s{width}")) - } else { - (format!("{raw_bits}"), format!("u{width}")) - }; - Some(pretty::Fragment::new([ - printer.numeric_literal_style().apply(printed_value), - literal_ty_suffix(ty), - ])) - } else { - None - } - } else if ty_opcode == wk.OpTypeFloat { - let width = match ty_imms[..] { - [spv::Imm::Short(_, width)] => width, - _ => unreachable!(), - }; + let def_without_name = match kind { + ConstKind::Undef => pretty::Fragment::new([ + printer.imperative_keyword_style().apply("undef").into(), + printer.pretty_type_ascription_suffix(*ty), + ]), + ConstKind::Scalar(scalar::Const::FALSE) => kw("false"), + ConstKind::Scalar(scalar::Const::TRUE) => kw("true"), + ConstKind::Scalar(ct) => { + let ty = ct.ty(); + let width = ty.bit_width(); + let (maybe_printed_value, ty_prefix) = match ty { + scalar::Type::Bool => unreachable!(), + scalar::Type::SInt(_) => (ct.int_as_i128().map(|x| x.to_string()), 's'), + scalar::Type::UInt(_) => (ct.int_as_u128().map(|x| x.to_string()), 'u'), + scalar::Type::Float(_) => { /// Check that parsing the result of printing produces /// the original bits of the floating-point value, and /// only return `Some` if that is the case. 
@@ -2546,69 +2472,81 @@ impl Print for ConstDef { }) } - let printed_value = match width { - 32 => bitwise_roundtrip_float_print( - raw_bits as u32, - f32::from_bits, - f32::to_bits, - ), - 64 => bitwise_roundtrip_float_print( - raw_bits, - f64::from_bits, - f64::to_bits, - ), - _ => None, - }; - printed_value.map(|s| { - pretty::Fragment::new([ - printer.numeric_literal_style().apply(s), - literal_ty_suffix(format!("f{width}")), - ]) - }) - } else { - None + ( + match width { + 32 => bitwise_roundtrip_float_print( + ct.bits() as u32, + f32::from_bits, + f32::to_bits, + ), + 64 => bitwise_roundtrip_float_print( + ct.bits() as u64, + f64::from_bits, + f64::to_bits, + ), + _ => None, + }, + 'f', + ) } - } else { - None + }; + match maybe_printed_value { + Some(printed_value) => { + let literal_ty_suffix = pretty::Styles { + // HACK(eddyb) the exact type detracts from the value. + color_opacity: Some(0.4), + subscript: true, + ..printer.declarative_keyword_style() + } + .apply(format!("{ty_prefix}{width}")); + pretty::Fragment::new([ + printer.numeric_literal_style().apply(printed_value), + literal_ty_suffix, + ]) + } + // HACK(eddyb) fallback using the bitwise representation. + None => pretty::Fragment::new([ + printer + .demote_style_for_namespace_prefix(printer.declarative_keyword_style()) + .apply(format!("{ty_prefix}{width}.")) + .into(), + printer.declarative_keyword_style().apply("from_bits").into(), + pretty::join_comma_sep( + "(", + [ + // FIXME(eddyb) consider padding this with enough + // leading zeroes for its respective width. 
+ printer.numeric_literal_style().apply(format!("0x{:x}", ct.bits())), + ], + ")", + ), + ]), } - } else { - None } - } else { - None - }; + &ConstKind::PtrToGlobalVar(gv) => { + pretty::Fragment::new(["&".into(), gv.print(printer)]) + } - AttrsAndDef { - attrs: attrs.print(printer), - def_without_name: compact_def.unwrap_or_else(|| match kind { - ConstKind::Undef => pretty::Fragment::new([ - printer.imperative_keyword_style().apply("undef").into(), + ConstKind::SpvInst { spv_inst_and_const_inputs } => { + let (spv_inst, const_inputs) = &**spv_inst_and_const_inputs; + pretty::Fragment::new([ + printer.pretty_spv_inst( + printer.spv_op_style(), + spv_inst.opcode, + &spv_inst.imms, + const_inputs.iter().map(|ct| ct.print(printer)), + ), printer.pretty_type_ascription_suffix(*ty), - ]), - &ConstKind::PtrToGlobalVar(gv) => { - pretty::Fragment::new(["&".into(), gv.print(printer)]) - } - - ConstKind::SpvInst { spv_inst_and_const_inputs } => { - let (spv_inst, const_inputs) = &**spv_inst_and_const_inputs; - pretty::Fragment::new([ - printer.pretty_spv_inst( - printer.spv_op_style(), - spv_inst.opcode, - &spv_inst.imms, - const_inputs.iter().map(|ct| ct.print(printer)), - ), - printer.pretty_type_ascription_suffix(*ty), - ]) - } - &ConstKind::SpvStringLiteralForExtInst(s) => pretty::Fragment::new([ - printer.pretty_spv_opcode(printer.spv_op_style(), wk.OpString), - "(".into(), - printer.pretty_string_literal(&printer.cx[s]), - ")".into(), - ]), - }), - } + ]) + } + &ConstKind::SpvStringLiteralForExtInst(s) => pretty::Fragment::new([ + printer.pretty_spv_opcode(printer.spv_op_style(), wk.OpString), + "(".into(), + printer.pretty_string_literal(&printer.cx[s]), + ")".into(), + ]), + }; + AttrsAndDef { attrs: attrs.print(printer), def_without_name } } } @@ -3299,21 +3237,17 @@ impl Print for FuncAt<'_, DataInst> { let pseudo_imm_from_value = |v: Value| { if let Value::Const(ct) = v { match &printer.cx[ct].kind { - ConstKind::Undef | ConstKind::PtrToGlobalVar(_) => {} + 
ConstKind::Undef + | ConstKind::PtrToGlobalVar(_) + | ConstKind::SpvInst { .. } => {} &ConstKind::SpvStringLiteralForExtInst(s) => { return Some(PseudoImm::Str(&printer.cx[s])); } - ConstKind::SpvInst { spv_inst_and_const_inputs } => { - let (spv_inst, _const_inputs) = &**spv_inst_and_const_inputs; - if spv_inst.opcode == wk.OpConstant { - if let [spv::Imm::Short(_, x)] = spv_inst.imms[..] { - // HACK(eddyb) only allow unambiguously positive values. - if i32::try_from(x).and_then(u32::try_from) == Ok(x) { - return Some(PseudoImm::U32(x)); - } - } - } + // HACK(eddyb) lossless roundtrip through `i32` is most conservative + // option (only `0..=i32::MAX`, i.e. `0 <= x < 2**31`, is allowed). + ConstKind::Scalar(ct) => { + return Some(PseudoImm::U32(u32::try_from(ct.int_as_i32()?).ok()?)); } } } diff --git a/src/qptr/layout.rs b/src/qptr/layout.rs index 00def111..49cd48ed 100644 --- a/src/qptr/layout.rs +++ b/src/qptr/layout.rs @@ -2,7 +2,7 @@ use crate::qptr::shapes; use crate::{ - spv, AddrSpace, Attr, Const, ConstKind, Context, Diag, FxIndexMap, Type, TypeKind, TypeOrConst, + scalar, spv, AddrSpace, Attr, Const, Context, Diag, FxIndexMap, Type, TypeKind, TypeOrConst, }; use itertools::Either; use smallvec::SmallVec; @@ -182,18 +182,10 @@ impl<'a> LayoutCache<'a> { Self { cx, wk: &spv::spec::Spec::get().well_known, config, cache: Default::default() } } - // FIXME(eddyb) properly distinguish between zero-extension and sign-extension. fn const_as_u32(&self, ct: Const) -> Option { - if let ConstKind::SpvInst { spv_inst_and_const_inputs } = &self.cx[ct].kind { - let (spv_inst, _const_inputs) = &**spv_inst_and_const_inputs; - if spv_inst.opcode == self.wk.OpConstant && spv_inst.imms.len() == 1 { - match spv_inst.imms[..] { - [spv::Imm::Short(_, x)] => return Some(x), - _ => unreachable!(), - } - } - } - None + // HACK(eddyb) lossless roundtrip through `i32` is most conservative + // option (only `0..=i32::MAX`, i.e. `0 <= x < 2**31`, is allowed). 
+ u32::try_from(ct.as_scalar(&self.cx)?.int_as_i32()?).ok() } /// Attempt to compute a `TypeLayout` for a given (SPIR-V) `Type`. @@ -202,26 +194,16 @@ impl<'a> LayoutCache<'a> { return Ok(cached); } + let layout = self.layout_of_uncached(ty)?; + self.cache.borrow_mut().insert(ty, layout.clone()); + Ok(layout) + } + + fn layout_of_uncached(&self, ty: Type) -> Result { let cx = &self.cx; let wk = self.wk; let ty_def = &cx[ty]; - let (spv_inst, type_and_const_inputs) = match &ty_def.kind { - // FIXME(eddyb) treat `QPtr`s as scalars. - TypeKind::QPtr => { - return Err(LayoutError(Diag::bug( - ["`layout_of(qptr)` (already lowered?)".into()], - ))); - } - TypeKind::SpvInst { spv_inst, type_and_const_inputs } => { - (spv_inst, type_and_const_inputs) - } - TypeKind::SpvStringLiteralForExtInst => { - return Err(LayoutError(Diag::bug([ - "`layout_of(type_of(OpString<\"...\">))`".into() - ]))); - } - }; let scalar_with_size_and_align = |(size, align)| { TypeLayout::Concrete(Rc::new(MemTypeLayout { @@ -340,25 +322,43 @@ impl<'a> LayoutCache<'a> { } } }; - let short_imm_at = |i| match spv_inst.imms[i] { - spv::Imm::Short(_, x) => x, - _ => unreachable!(), - }; // FIXME(eddyb) !!! what if... types had a min/max size and then... // that would allow surrounding offsets to limit their size... but... ugh... // ugh this doesn't make any sense. maybe if the front-end specifies // offsets with "abstract types", it must configure `qptr::layout`? - let layout = if spv_inst.opcode == wk.OpTypeBool { - // FIXME(eddyb) make this properly abstract instead of only configurable. - scalar_with_size_and_align(self.config.abstract_bool_size_align) - } else if spv_inst.opcode == wk.OpTypePointer { + + let (spv_inst, type_and_const_inputs) = match &ty_def.kind { + TypeKind::Scalar(scalar::Type::Bool) => { + // FIXME(eddyb) make this properly abstract instead of only configurable. 
+ return Ok(scalar_with_size_and_align(self.config.abstract_bool_size_align)); + } + TypeKind::Scalar(ty) => return Ok(scalar(ty.bit_width())), + + // FIXME(eddyb) treat `QPtr`s as scalars. + TypeKind::QPtr => { + return Err(LayoutError(Diag::bug( + ["`layout_of(qptr)` (already lowered?)".into()], + ))); + } + TypeKind::SpvInst { spv_inst, type_and_const_inputs } => { + (spv_inst, type_and_const_inputs) + } + TypeKind::SpvStringLiteralForExtInst => { + return Err(LayoutError(Diag::bug([ + "`layout_of(type_of(OpString<\"...\">))`".into() + ]))); + } + }; + let short_imm_at = |i| match spv_inst.imms[i] { + spv::Imm::Short(_, x) => x, + _ => unreachable!(), + }; + Ok(if spv_inst.opcode == wk.OpTypePointer { // FIXME(eddyb) make this properly abstract instead of only configurable. // FIXME(eddyb) categorize `OpTypePointer` by storage class and split on // logical vs physical here. scalar_with_size_and_align(self.config.logical_ptr_size_align) - } else if [wk.OpTypeInt, wk.OpTypeFloat].contains(&spv_inst.opcode) { - scalar(short_imm_at(0)) } else if [wk.OpTypeVector, wk.OpTypeMatrix].contains(&spv_inst.opcode) { let len = short_imm_at(0); let (min_legacy_align, legacy_align_multiplier) = if spv_inst.opcode == wk.OpTypeVector @@ -642,8 +642,6 @@ impl<'a> LayoutCache<'a> { spv_inst.opcode.name() ) .into()]))); - }; - self.cache.borrow_mut().insert(ty, layout.clone()); - Ok(layout) + }) } } diff --git a/src/qptr/lift.rs b/src/qptr/lift.rs index f962ded5..aa5162fe 100644 --- a/src/qptr/lift.rs +++ b/src/qptr/lift.rs @@ -7,13 +7,12 @@ use crate::func_at::FuncAtMut; use crate::qptr::{shapes, QPtrAttr, QPtrMemUsage, QPtrMemUsageKind, QPtrOp, QPtrUsage}; use crate::transform::{InnerInPlaceTransform, InnerTransform, Transformed, Transformer}; use crate::{ - spv, AddrSpace, Attr, AttrSet, AttrSetDef, Const, ConstDef, ConstKind, Context, ControlNode, - ControlNodeKind, DataInst, DataInstDef, DataInstFormDef, DataInstKind, DeclDef, Diag, - DiagLevel, EntityDefs, 
EntityOrientedDenseMap, Func, FuncDecl, FxIndexMap, GlobalVar, + scalar, spv, AddrSpace, Attr, AttrSet, AttrSetDef, Const, ConstDef, ConstKind, Context, + ControlNode, ControlNodeKind, DataInst, DataInstDef, DataInstFormDef, DataInstKind, DeclDef, + Diag, DiagLevel, EntityDefs, EntityOrientedDenseMap, Func, FuncDecl, FxIndexMap, GlobalVar, GlobalVarDecl, Module, Type, TypeDef, TypeKind, TypeOrConst, Value, }; use smallvec::SmallVec; -use std::cell::Cell; use std::mem; use std::num::NonZeroU32; use std::rc::Rc; @@ -27,8 +26,6 @@ pub struct LiftToSpvPtrs<'a> { cx: Rc, wk: &'static spv::spec::WellKnown, layout_cache: LayoutCache<'a>, - - cached_u32_type: Cell>, } impl<'a> LiftToSpvPtrs<'a> { @@ -37,7 +34,6 @@ impl<'a> LiftToSpvPtrs<'a> { cx: cx.clone(), wk: &spv::spec::Spec::get().well_known, layout_cache: LayoutCache::new(cx, layout_config), - cached_u32_type: Default::default(), } } @@ -291,7 +287,9 @@ impl<'a> LiftToSpvPtrs<'a> { spv_inst: spv_opcode.into(), type_and_const_inputs: [TypeOrConst::Type(element_type)] .into_iter() - .chain(fixed_len.map(|len| TypeOrConst::Const(self.const_u32(len)))) + .chain(fixed_len.map(|len| { + TypeOrConst::Const(self.cx.intern(scalar::Const::from_u32(len))) + })) .collect(), }, })) @@ -329,48 +327,6 @@ impl<'a> LiftToSpvPtrs<'a> { })) } - /// Get the (likely cached) `u32` type. 
- fn u32_type(&self) -> Type { - if let Some(cached) = self.cached_u32_type.get() { - return cached; - } - let wk = self.wk; - let ty = self.cx.intern(TypeKind::SpvInst { - spv_inst: spv::Inst { - opcode: wk.OpTypeInt, - imms: [ - spv::Imm::Short(wk.LiteralInteger, 32), - spv::Imm::Short(wk.LiteralInteger, 0), - ] - .into_iter() - .collect(), - }, - type_and_const_inputs: [].into_iter().collect(), - }); - self.cached_u32_type.set(Some(ty)); - ty - } - - fn const_u32(&self, x: u32) -> Const { - let wk = self.wk; - - self.cx.intern(ConstDef { - attrs: AttrSet::default(), - ty: self.u32_type(), - kind: ConstKind::SpvInst { - spv_inst_and_const_inputs: Rc::new(( - spv::Inst { - opcode: wk.OpConstant, - imms: [spv::Imm::Short(wk.LiteralContextDependentNumber, x)] - .into_iter() - .collect(), - }, - [].into_iter().collect(), - )), - }, - }) - } - /// Attempt to compute a `TypeLayout` for a given (SPIR-V) `Type`. fn layout_of(&self, ty: Type) -> Result { self.layout_cache.layout_of(ty).map_err(|LayoutError(err)| LiftError(err)) @@ -644,7 +600,7 @@ impl LiftToSpvPtrInstsInFunc<'_> { ])) })?; access_chain_inputs - .push(Value::Const(self.lifter.const_u32(idx_as_i32 as u32))); + .push(Value::Const(cx.intern(scalar::Const::from_u32(idx_as_i32 as u32)))); match &layout.components { Components::Scalar => unreachable!(), @@ -757,7 +713,7 @@ impl LiftToSpvPtrInstsInFunc<'_> { ])) })?; access_chain_inputs - .push(Value::Const(self.lifter.const_u32(idx_as_i32 as u32))); + .push(Value::Const(cx.intern(scalar::Const::from_u32(idx_as_i32 as u32)))); layout = match &layout.components { Components::Scalar => unreachable!(), @@ -945,7 +901,8 @@ impl LiftToSpvPtrInstsInFunc<'_> { let mut access_chain_inputs: SmallVec<_> = [ptr].into_iter().collect(); if let TypeLayout::HandleArray(handle, _) = pointee_layout { - access_chain_inputs.push(Value::Const(self.lifter.const_u32(0))); + access_chain_inputs + .push(Value::Const(self.lifter.cx.intern(scalar::Const::from_u32(0)))); pointee_layout = 
TypeLayout::Handle(handle); } match (pointee_layout, access_layout) { @@ -1014,8 +971,9 @@ impl LiftToSpvPtrInstsInFunc<'_> { format!("{idx} not representable as a positive s32").into() ])) })?; - access_chain_inputs - .push(Value::Const(self.lifter.const_u32(idx_as_i32 as u32))); + access_chain_inputs.push(Value::Const( + self.lifter.cx.intern(scalar::Const::from_u32(idx_as_i32 as u32)), + )); pointee_layout = match &pointee_layout.components { Components::Scalar => unreachable!(), diff --git a/src/qptr/lower.rs b/src/qptr/lower.rs index 512b6856..56a5861e 100644 --- a/src/qptr/lower.rs +++ b/src/qptr/lower.rs @@ -171,18 +171,10 @@ impl<'a> LowerFromSpvPtrs<'a> { } } - // FIXME(eddyb) properly distinguish between zero-extension and sign-extension. fn const_as_u32(&self, ct: Const) -> Option { - if let ConstKind::SpvInst { spv_inst_and_const_inputs } = &self.cx[ct].kind { - let (spv_inst, _const_inputs) = &**spv_inst_and_const_inputs; - if spv_inst.opcode == self.wk.OpConstant && spv_inst.imms.len() == 1 { - match spv_inst.imms[..] { - [spv::Imm::Short(_, x)] => return Some(x), - _ => unreachable!(), - } - } - } - None + // HACK(eddyb) lossless roundtrip through `i32` is most conservative + // option (only `0..=i32::MAX`, i.e. `0 <= x < 2**31`, is allowed). + u32::try_from(ct.as_scalar(&self.cx)?.int_as_i32()?).ok() } /// Get the (likely cached) `QPtr` type. diff --git a/src/scalar.rs b/src/scalar.rs new file mode 100644 index 00000000..e808fbe4 --- /dev/null +++ b/src/scalar.rs @@ -0,0 +1,198 @@ +//! Scalar (`bool`, integer, and floating-point) types and associated functionality. +//! +//! **Note**: pointers are never scalars (like SPIR-V, but unlike other IRs). + +// HACK(eddyb) this could be some `struct` with private fields, but this `enum` +// is only 2 bytes in size, and has better ergonomics overall.
+#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum Type { + Bool, + SInt(IntWidth), + UInt(IntWidth), + Float(FloatWidth), +} + +impl Type { + // HACK(eddyb) only common widths, as a convenience, expand as-needed. + pub const S32: Type = Type::SInt(IntWidth::I32); + pub const U32: Type = Type::UInt(IntWidth::I32); + pub const F32: Type = Type::Float(FloatWidth::F32); + pub const F64: Type = Type::Float(FloatWidth::F64); + + pub const fn bit_width(self) -> u32 { + match self { + Type::Bool => 1, + Type::SInt(w) | Type::UInt(w) => w.bits(), + Type::Float(w) => w.bits(), + } + } +} + +/// Bit-width of a supported integer type (only power-of-two multiples of a byte). +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub struct IntWidth { + // HACK(eddyb) this is so compact that only 3 bits of this byte are used + // to encode integer types from `i8` to `i128`, and so `Type` could all fit + // in one byte, but that'd need a new `enum` for `Bool`/`{S,U}Int`/`Float`. + log2_bytes: u8, +} + +impl IntWidth { + pub const I8: Self = Self::try_from_bits_unwrap(8); + pub const I16: Self = Self::try_from_bits_unwrap(16); + pub const I32: Self = Self::try_from_bits_unwrap(32); + pub const I64: Self = Self::try_from_bits_unwrap(64); + pub const I128: Self = Self::try_from_bits_unwrap(128); + + // FIXME(eddyb) remove when `Option::unwrap` is stabilized. + const fn try_from_bits_unwrap(bits: u32) -> Self { + match Self::try_from_bits(bits) { + Some(w) => w, + None => unreachable!(), + } + } + + pub const fn try_from_bits(bits: u32) -> Option { + if bits % 8 != 0 { + return None; + } + let bytes = bits / 8; + match bytes.checked_ilog2() { + Some(log2_bytes_u32) => { + let log2_bytes = log2_bytes_u32 as u8; + assert!(log2_bytes as u32 == log2_bytes_u32); + Some(Self { log2_bytes }) + } + None => None, + } + } + + pub const fn bits(self) -> u32 { + 8 * (1 << self.log2_bytes) + } +} + +/// Bit-width of a supported floating-point type (only power-of-two multiples of a byte). 
+#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub struct FloatWidth(IntWidth); + +impl FloatWidth { + pub const F32: Self = Self::try_from_bits_unwrap(32); + pub const F64: Self = Self::try_from_bits_unwrap(64); + + // FIXME(eddyb) remove when `Option::unwrap` is stabilized. + const fn try_from_bits_unwrap(bits: u32) -> Self { + match Self::try_from_bits(bits) { + Some(w) => w, + None => unreachable!(), + } + } + + pub const fn try_from_bits(bits: u32) -> Option { + match IntWidth::try_from_bits(bits) { + Some(w) => Some(Self(w)), + None => None, + } + } + + pub const fn bits(self) -> u32 { + self.0.bits() + } +} + +// FIXME(eddyb) document the 128-bit limitations. +// HACK(eddyb) `(Type, u128)` would waste almost half its size on padding, and +// packing will only impact accessing the `bits`, while allowing e.g. being +// wrapped in an outer `enum`, before reaching the same size as `(u128, u128)`. +#[repr(packed)] +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub struct Const { + ty: Type, + bits: u128, +} + +impl Const { + pub const FALSE: Const = Const::from_bool(false); + pub const TRUE: Const = Const::from_bool(true); + + // FIXME(eddyb) document the panic conditions. + // FIXME(eddyb) make this public? + const fn from_bits_trunc(ty: Type, bits: u128) -> Const { + // FIXME(eddyb) this ensures `Const`s cannot be created when that could + // potentially need more than 128 bits for e.g. constant-folding. + let width = ty.bit_width(); + assert!(width <= 128); + + Const { ty, bits: bits & (!0u128 >> (128 - width)) } + } + + // FIXME(eddyb) document the panic conditions. 
+ pub const fn from_bits(ty: Type, bits: u128) -> Const { + let ct_trunc = Const::from_bits_trunc(ty, bits); + assert!(ct_trunc.bits == bits); + ct_trunc + } + + pub const fn try_from_bits(ty: Type, bits: u128) -> Option { + let ct_trunc = Const::from_bits_trunc(ty, bits); + if ct_trunc.bits == bits { Some(ct_trunc) } else { None } + } + + pub const fn from_bool(v: bool) -> Const { + Const::from_bits(Type::Bool, v as u128) + } + + pub const fn from_u32(v: u32) -> Const { + Const::from_bits(Type::U32, v as u128) + } + + /// Returns `Some(ct)` iff `ty` is `{S,U}Int` and can represent `v: i128` + /// (i.e. `ct` has the same sign and absolute value as `v` does). + pub fn int_try_from_i128(ty: Type, v: i128) -> Option { + let ct_trunc = Const::from_bits_trunc(ty, v as u128); + (ct_trunc.int_as_i128() == Some(v)).then_some(ct_trunc) + } + + pub const fn ty(&self) -> Type { + self.ty + } + + pub const fn bits(&self) -> u128 { + self.bits + } + + /// Returns `Some(v)` iff `self` is `{S,U}Int` and representable by `v: i128` + /// (i.e. `self` has the same sign and absolute value as `v` does). + pub fn int_as_i128(&self) -> Option { + match self.ty { + Type::Bool | Type::Float(_) => None, + Type::SInt(_) => { + let width = self.ty.bit_width(); + Some((self.bits as i128) << (128 - width) >> (128 - width)) + } + Type::UInt(_) => self.bits.try_into().ok(), + } + } + + /// Returns `Some(v)` iff `self` is `{S,U}Int` and representable by `v: u128` + /// (i.e. `self` is positive and has the same absolute value as `v` does). + pub fn int_as_u128(&self) -> Option { + match self.ty { + Type::Bool | Type::Float(_) => None, + Type::SInt(_) => self.int_as_i128()?.try_into().ok(), + Type::UInt(_) => Some(self.bits), + } + } + + /// Returns `Some(v)` iff `self` is `{S,U}Int` and representable by `v: i32` + /// (i.e. `self` has the same sign and absolute value as `v` does). 
+ pub fn int_as_i32(&self) -> Option { + self.int_as_i128()?.try_into().ok() + } + + /// Returns `Some(v)` iff `self` is `{S,U}Int` and representable by `v: u32` + /// (i.e. `self` is positive and has the same absolute value as `v` does). + pub fn int_as_u32(&self) -> Option { + self.int_as_u128()?.try_into().ok() + } +} diff --git a/src/spv/canonical.rs b/src/spv/canonical.rs index 158e879d..b538ee3a 100644 --- a/src/spv/canonical.rs +++ b/src/spv/canonical.rs @@ -8,7 +8,7 @@ // FIXME(eddyb) should interning attempts check/apply these canonicalizations? use crate::spv::{self, spec}; -use crate::ConstKind; +use crate::{scalar, ConstKind, Context, Type, TypeKind}; use lazy_static::lazy_static; // FIXME(eddyb) these ones could maybe make use of build script generation. @@ -36,23 +36,174 @@ macro_rules! def_mappable_ops { }; } def_mappable_ops! { + OpTypeBool, + OpTypeInt, + OpTypeFloat, + OpUndef, + OpConstantFalse, + OpConstantTrue, + OpConstant, +} + +impl scalar::Const { + fn try_decode_from_spv_imms(ty: scalar::Type, imms: &[spv::Imm]) -> Option { + // FIXME(eddyb) don't hardcode the 128-bit limitation, + // but query `scalar::Const` somehow instead. + if ty.bit_width() > 128 { + return None; + } + let imm_words = usize::try_from(ty.bit_width().div_ceil(32)).unwrap(); + if imms.len() != imm_words { + return None; + } + let mut bits = 0; + for (i, &imm) in imms.iter().enumerate() { + let w = match imm { + spv::Imm::Short(_, w) if imm_words == 1 => w, + spv::Imm::LongStart(_, w) if i == 0 && imm_words > 1 => w, + spv::Imm::LongCont(_, w) if i > 0 => w, + _ => return None, + }; + bits |= (w as u128) << (i * 32); + } + + // HACK(eddyb) signed integers are encoded sign-extended into immediates. 
+ if let scalar::Type::SInt(_) = ty { + let imm_width = imm_words * 32; + scalar::Const::int_try_from_i128( + ty, + (bits as i128) << (128 - imm_width) >> (128 - imm_width), + ) + } else { + scalar::Const::try_from_bits(ty, bits) + } + } + + fn encode_as_spv_imms(&self) -> impl Iterator { + let wk = &spec::Spec::get().well_known; + + let ty = self.ty(); + let imm_words = ty.bit_width().div_ceil(32); + + let bits = self.bits(); + + // HACK(eddyb) signed integers are encoded sign-extended into immediates. + let bits = if let scalar::Type::SInt(_) = ty { + let imm_width = imm_words * 32; + (self.int_as_i128().unwrap() as u128) & (!0 >> (128 - imm_width)) + } else { + bits + }; + + (0..imm_words).map(move |i| { + let k = wk.LiteralContextDependentNumber; + let w = (bits >> (i * 32)) as u32; + if imm_words == 1 { + spv::Imm::Short(k, w) + } else if i == 0 { + spv::Imm::LongStart(k, w) + } else { + spv::Imm::LongCont(k, w) + } + }) + } } // FIXME(eddyb) decide on a visibility scope - `pub(super)` avoids some mistakes // (using these methods outside of `spv::{lower,lift}`), but may be too restrictive. impl spv::Inst { - pub(super) fn as_canonical_const(&self) -> Option { + // HACK(eddyb) exported only for `spv::read`/`LiteralContextDependentNumber`. + pub(super) fn int_or_float_type_bit_width(&self) -> Option { + let mo = MappableOps::get(); + + match self.imms[..] { + [spv::Imm::Short(_, bit_width), _] if self.opcode == mo.OpTypeInt => Some(bit_width), + [spv::Imm::Short(_, bit_width)] if self.opcode == mo.OpTypeFloat => Some(bit_width), + _ => None, + } + } + + // FIXME(eddyb) automate bidirectional mappings more (although the need + // for conditional, i.e. "partial", mappings, adds a lot of complexity). 
+ pub(super) fn as_canonical_type(&self) -> Option { let Self { opcode, imms } = self; let (&opcode, imms) = (opcode, &imms[..]); let mo = MappableOps::get(); - if opcode == mo.OpUndef { - assert_eq!(imms.len(), 0); - Some(ConstKind::Undef) - } else { - None + let int_width = || scalar::IntWidth::try_from_bits(self.int_or_float_type_bit_width()?); + match imms { + [] if opcode == mo.OpTypeBool => Some(scalar::Type::Bool.into()), + &[_, spv::Imm::Short(_, 0)] if opcode == mo.OpTypeInt => { + Some(scalar::Type::UInt(int_width()?).into()) + } + &[_, spv::Imm::Short(_, 1)] if opcode == mo.OpTypeInt => { + Some(scalar::Type::SInt(int_width()?).into()) + } + [_] if opcode == mo.OpTypeFloat => Some( + scalar::Type::Float(scalar::FloatWidth::try_from_bits( + self.int_or_float_type_bit_width()?, + )?) + .into(), + ), + _ => None, + } + } + + pub(super) fn from_canonical_type(type_kind: &TypeKind) -> Option { + let wk = &spec::Spec::get().well_known; + let mo = MappableOps::get(); + + match type_kind { + &TypeKind::Scalar(ty) => match ty { + scalar::Type::Bool => Some(mo.OpTypeBool.into()), + scalar::Type::SInt(w) | scalar::Type::UInt(w) => Some(spv::Inst { + opcode: mo.OpTypeInt, + imms: [ + spv::Imm::Short(wk.LiteralInteger, w.bits()), + spv::Imm::Short( + wk.LiteralInteger, + matches!(ty, scalar::Type::SInt(_)) as u32, + ), + ] + .into_iter() + .collect(), + }), + scalar::Type::Float(w) => Some(spv::Inst { + opcode: mo.OpTypeFloat, + imms: [spv::Imm::Short(wk.LiteralInteger, w.bits())].into_iter().collect(), + }), + }, + + TypeKind::QPtr | TypeKind::SpvInst { .. } | TypeKind::SpvStringLiteralForExtInst => { + None + } + } + } + + // HACK(eddyb) this only exists as a helper for `spv::lower`. + pub(super) fn always_lower_as_const(&self) -> bool { + let mo = MappableOps::get(); + mo.OpUndef == self.opcode + } + + // FIXME(eddyb) automate bidirectional mappings more (although the need + // for conditional, i.e. "partial", mappings, adds a lot of complexity). 
+ pub(super) fn as_canonical_const(&self, cx: &Context, ty: Type) -> Option { + let Self { opcode, imms } = self; + let (&opcode, imms) = (opcode, &imms[..]); + + let mo = MappableOps::get(); + + match imms { + [] if opcode == mo.OpUndef => Some(ConstKind::Undef), + [] if opcode == mo.OpConstantFalse => Some(scalar::Const::FALSE.into()), + [] if opcode == mo.OpConstantTrue => Some(scalar::Const::TRUE.into()), + _ if opcode == mo.OpConstant => { + Some(scalar::Const::try_decode_from_spv_imms(ty.as_scalar(cx)?, imms)?.into()) + } + _ => None, } } @@ -61,6 +212,11 @@ impl spv::Inst { match const_kind { ConstKind::Undef => Some(mo.OpUndef.into()), + ConstKind::Scalar(scalar::Const::FALSE) => Some(mo.OpConstantFalse.into()), + ConstKind::Scalar(scalar::Const::TRUE) => Some(mo.OpConstantTrue.into()), + ConstKind::Scalar(ct) => { + Some(spv::Inst { opcode: mo.OpConstant, imms: ct.encode_as_spv_imms().collect() }) + } ConstKind::PtrToGlobalVar(_) | ConstKind::SpvInst { .. } diff --git a/src/spv/lift.rs b/src/spv/lift.rs index a74db0f2..0cb139d2 100644 --- a/src/spv/lift.rs +++ b/src/spv/lift.rs @@ -4,7 +4,7 @@ use crate::func_at::FuncAt; use crate::spv::{self, spec}; use crate::visit::{InnerVisit, Visitor}; use crate::{ - cfg, AddrSpace, Attr, AttrSet, Const, ConstDef, ConstKind, Context, ControlNode, + cfg, scalar, AddrSpace, Attr, AttrSet, Const, ConstDef, ConstKind, Context, ControlNode, ControlNodeKind, ControlNodeOutputDecl, ControlRegion, ControlRegionInputDecl, DataInst, DataInstDef, DataInstForm, DataInstFormDef, DataInstKind, DeclDef, EntityList, ExportKey, Exportee, Func, FuncDecl, FuncParam, FxIndexMap, FxIndexSet, GlobalVar, GlobalVarDefBody, @@ -122,13 +122,14 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { } let ty_def = &self.cx[ty]; match ty_def.kind { + TypeKind::Scalar(_) | TypeKind::SpvInst { .. } => {} + // FIXME(eddyb) this should be a proper `Result`-based error instead, // and/or `spv::lift` should mutate the module for legalization. 
TypeKind::QPtr => { unreachable!("`TypeKind::QPtr` should be legalized away before lifting"); } - TypeKind::SpvInst { .. } => {} TypeKind::SpvStringLiteralForExtInst => { unreachable!( "`TypeKind::SpvStringLiteralForExtInst` should not be used \ @@ -146,7 +147,10 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { } let ct_def = &self.cx[ct]; match ct_def.kind { - ConstKind::Undef | ConstKind::PtrToGlobalVar(_) | ConstKind::SpvInst { .. } => { + ConstKind::Undef + | ConstKind::Scalar(_) + | ConstKind::PtrToGlobalVar(_) + | ConstKind::SpvInst { .. } => { self.visit_const_def(ct_def); self.globals.insert(global); } @@ -522,8 +526,6 @@ impl<'a> FuncLifting<'a> { func_decl: &'a FuncDecl, mut alloc_id: impl FnMut() -> Result, ) -> Result { - let wk = &spec::Spec::get().well_known; - let func_id = alloc_id()?; let param_ids = func_decl.params.iter().map(|_| alloc_id()).collect::>()?; @@ -758,15 +760,9 @@ impl<'a> FuncLifting<'a> { .collect(); let is_infinite_loop = match repeat_condition { - Value::Const(cond) => match &cx[cond].kind { - ConstKind::SpvInst { spv_inst_and_const_inputs } => { - let (spv_inst, _const_inputs) = - &**spv_inst_and_const_inputs; - spv_inst.opcode == wk.OpConstantTrue - } - _ => false, - }, - + Value::Const(cond) => { + matches!(cx[cond].kind, ConstKind::Scalar(scalar::Const::TRUE)) + } _ => false, }; if is_infinite_loop { @@ -1037,7 +1033,9 @@ impl LazyInst<'_, '_> { (gv_decl.attrs, import) } - ConstKind::Undef | ConstKind::SpvInst { .. } => (ct_def.attrs, None), + ConstKind::Undef | ConstKind::Scalar(_) | ConstKind::SpvInst { .. } => { + (ct_def.attrs, None) + } // Not inserted into `globals` while visiting. 
ConstKind::SpvStringLiteralForExtInst(_) => unreachable!(), @@ -1103,25 +1101,43 @@ impl LazyInst<'_, '_> { let (result_id, attrs, _) = self.result_id_attrs_and_import(module, ids); let inst = match self { Self::Global(global) => match global { - Global::Type(ty) => match &cx[ty].kind { - TypeKind::SpvInst { spv_inst, type_and_const_inputs } => spv::InstWithIds { - without_ids: spv_inst.clone(), - result_type_id: None, - result_id, - ids: type_and_const_inputs - .iter() - .map(|&ty_or_ct| { - ids.globals[&match ty_or_ct { - TypeOrConst::Type(ty) => Global::Type(ty), - TypeOrConst::Const(ct) => Global::Const(ct), - }] - }) - .collect(), - }, + Global::Type(ty) => { + let ty_def = &cx[ty]; + match spv::Inst::from_canonical_type(&ty_def.kind).ok_or(&ty_def.kind) { + Ok(spv_inst) => spv::InstWithIds { + without_ids: spv_inst, + result_type_id: None, + result_id, + ids: [].into_iter().collect(), + }, - // Not inserted into `globals` while visiting. - TypeKind::QPtr | TypeKind::SpvStringLiteralForExtInst => unreachable!(), - }, + Err(TypeKind::Scalar(_)) => { + unreachable!("should've been handled as canonical") + } + + Err(TypeKind::SpvInst { spv_inst, type_and_const_inputs }) => { + spv::InstWithIds { + without_ids: spv_inst.clone(), + result_type_id: None, + result_id, + ids: type_and_const_inputs + .iter() + .map(|&ty_or_ct| { + ids.globals[&match ty_or_ct { + TypeOrConst::Type(ty) => Global::Type(ty), + TypeOrConst::Const(ct) => Global::Const(ct), + }] + }) + .collect(), + } + } + + // Not inserted into `globals` while visiting. 
+ Err(TypeKind::QPtr | TypeKind::SpvStringLiteralForExtInst) => { + unreachable!() + } + } + } Global::Const(ct) => { let ct_def = &cx[ct]; match spv::Inst::from_canonical_const(&ct_def.kind).ok_or(&ct_def.kind) { @@ -1132,7 +1148,7 @@ impl LazyInst<'_, '_> { ids: [].into_iter().collect(), }, - Err(ConstKind::Undef) => { + Err(ConstKind::Undef | ConstKind::Scalar(_)) => { unreachable!("should've been handled as canonical") } diff --git a/src/spv/lower.rs b/src/spv/lower.rs index 6b26ac1a..0c6a6984 100644 --- a/src/spv/lower.rs +++ b/src/spv/lower.rs @@ -557,7 +557,7 @@ impl Module { } else if inst_category == spec::InstructionCategory::Type { assert!(inst.result_type_id.is_none()); let id = inst.result_id.unwrap(); - let type_and_const_inputs = inst + let type_and_const_inputs: SmallVec<_> = inst .ids .iter() .map(|&id| match id_defs.get(&id) { @@ -575,36 +575,26 @@ impl Module { let ty = cx.intern(TypeDef { attrs: mem::take(&mut attrs), - kind: TypeKind::SpvInst { spv_inst: inst.without_ids, type_and_const_inputs }, + kind: match inst.as_canonical_type() { + Some(type_kind) => { + assert_eq!(type_and_const_inputs.len(), 0); + type_kind + } + None => { + TypeKind::SpvInst { spv_inst: inst.without_ids, type_and_const_inputs } + } + }, }); id_defs.insert(id, IdDef::Type(ty)); Seq::TypeConstOrGlobalVar - } else if let Some(const_kind) = inst.as_canonical_const() { + } else if inst_category == spec::InstructionCategory::Const + || inst.always_lower_as_const() + { let id = inst.result_id.unwrap(); - assert_eq!(inst.ids.len(), 0); + let ty = result_type.unwrap(); - // FIXME(eddyb) this is used below for sequencing, so maybe it - // may be useful to still have some access here to `wk.OpUndef`. 
- let is_op_undef = matches!(const_kind, ConstKind::Undef); - - let ct = cx.intern(ConstDef { - attrs: mem::take(&mut attrs), - ty: result_type.unwrap(), - kind: const_kind, - }); - id_defs.insert(id, IdDef::Const(ct)); - - if is_op_undef { - // `OpUndef` can appear either among constants, or in a - // function, so at most advance `seq` to globals. - seq.max(Some(Seq::TypeConstOrGlobalVar)).unwrap() - } else { - Seq::TypeConstOrGlobalVar - } - } else if inst_category == spec::InstructionCategory::Const { - let id = inst.result_id.unwrap(); - let const_inputs = inst + let const_inputs: SmallVec<_> = inst .ids .iter() .map(|&id| match id_defs.get(&id) { @@ -621,14 +611,26 @@ impl Module { let ct = cx.intern(ConstDef { attrs: mem::take(&mut attrs), - ty: result_type.unwrap(), - kind: ConstKind::SpvInst { - spv_inst_and_const_inputs: Rc::new((inst.without_ids, const_inputs)), + ty, + kind: match inst.as_canonical_const(&cx, ty) { + Some(const_kind) => { + assert_eq!(const_inputs.len(), 0); + const_kind + } + None => ConstKind::SpvInst { + spv_inst_and_const_inputs: Rc::new((inst.without_ids, const_inputs)), + }, }, }); id_defs.insert(id, IdDef::Const(ct)); - Seq::TypeConstOrGlobalVar + if inst_category != spec::InstructionCategory::Const { + // `OpUndef` can appear either among constants, or in a + // function, so at most advance `seq` to globals. + seq.max(Some(Seq::TypeConstOrGlobalVar)).unwrap() + } else { + Seq::TypeConstOrGlobalVar + } } else if opcode == wk.OpVariable && current_func_body.is_none() { let global_var_id = inst.result_id.unwrap(); let type_of_ptr_to_global_var = result_type.unwrap(); diff --git a/src/spv/read.rs b/src/spv/read.rs index 58178055..5acf2930 100644 --- a/src/spv/read.rs +++ b/src/spv/read.rs @@ -12,15 +12,14 @@ use std::{fs, io, iter, slice}; /// /// Used currently only to help parsing `LiteralContextDependentNumber`. 
enum KnownIdDef { - TypeInt(NonZeroU32), - TypeFloat(NonZeroU32), + TypeIntOrFloat(NonZeroU32), Uncategorized { opcode: spec::Opcode, result_type_id: Option }, } impl KnownIdDef { fn result_type_id(&self) -> Option { match *self { - Self::TypeInt(_) | Self::TypeFloat(_) => None, + Self::TypeIntOrFloat(_) => None, Self::Uncategorized { result_type_id, .. } => result_type_id, } } @@ -175,7 +174,7 @@ impl InstParser<'_> { .ok_or(Error::MissingContextSensitiveLiteralType)?; let extra_word_count = match *contextual_type { - KnownIdDef::TypeInt(width) | KnownIdDef::TypeFloat(width) => { + KnownIdDef::TypeIntOrFloat(width) => { // HACK(eddyb) `(width + 31) / 32 - 1` but without overflow. (width.get() - 1) / 32 } @@ -304,9 +303,6 @@ impl ModuleParser { impl Iterator for ModuleParser { type Item = io::Result; fn next(&mut self) -> Option { - let spv_spec = spec::Spec::get(); - let wk = &spv_spec.well_known; - let words = &bytemuck::cast_slice::(&self.word_bytes)[self.next_word..]; let &opcode = words.first()?; @@ -341,24 +337,11 @@ impl Iterator for ModuleParser { // HACK(eddyb) `Option::map` allows using `?` for `Result` in the closure. let maybe_known_id_result = inst.result_id.map(|id| { - let known_id_def = if opcode == wk.OpTypeInt { - KnownIdDef::TypeInt(match inst.imms[0] { - spv::Imm::Short(kind, n) => { - assert_eq!(kind, wk.LiteralInteger); - n.try_into().ok().ok_or_else(|| invalid("Width cannot be 0"))? - } - _ => unreachable!(), - }) - } else if opcode == wk.OpTypeFloat { - KnownIdDef::TypeFloat(match inst.imms[0] { - spv::Imm::Short(kind, n) => { - assert_eq!(kind, wk.LiteralInteger); - n.try_into().ok().ok_or_else(|| invalid("Width cannot be 0"))? 
- } - _ => unreachable!(), - }) - } else { - KnownIdDef::Uncategorized { opcode, result_type_id: inst.result_type_id } + let known_id_def = match inst.int_or_float_type_bit_width() { + Some(w) => KnownIdDef::TypeIntOrFloat( + w.try_into().ok().ok_or_else(|| invalid("Width cannot be 0"))?, + ), + None => KnownIdDef::Uncategorized { opcode, result_type_id: inst.result_type_id }, }; let old = self.known_ids.insert(id, known_id_def); diff --git a/src/spv/spec.rs b/src/spv/spec.rs index 06b8f551..7fe89260 100644 --- a/src/spv/spec.rs +++ b/src/spv/spec.rs @@ -117,9 +117,6 @@ def_well_known! { OpNoLine, OpTypeVoid, - OpTypeBool, - OpTypeInt, - OpTypeFloat, OpTypeVector, OpTypeMatrix, OpTypeArray, @@ -133,10 +130,6 @@ def_well_known! { OpTypeSampledImage, OpTypeAccelerationStructureKHR, - OpConstantFalse, - OpConstantTrue, - OpConstant, - OpVariable, OpFunction, diff --git a/src/transform.rs b/src/transform.rs index add05b20..83c626af 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -424,7 +424,9 @@ impl InnerTransform for TypeDef { transform!({ attrs -> transformer.transform_attr_set_use(*attrs), kind -> match kind { - TypeKind::QPtr | TypeKind::SpvStringLiteralForExtInst => Transformed::Unchanged, + TypeKind::Scalar(_) + | TypeKind::QPtr + | TypeKind::SpvStringLiteralForExtInst => Transformed::Unchanged, TypeKind::SpvInst { spv_inst, type_and_const_inputs } => Transformed::map_iter( type_and_const_inputs.iter(), @@ -458,6 +460,7 @@ impl InnerTransform for ConstDef { ty -> transformer.transform_type_use(*ty), kind -> match kind { ConstKind::Undef + | ConstKind::Scalar(_) | ConstKind::SpvStringLiteralForExtInst(_) => Transformed::Unchanged, ConstKind::PtrToGlobalVar(gv) => transform!({ diff --git a/src/visit.rs b/src/visit.rs index a1ec4a73..665b4a6d 100644 --- a/src/visit.rs +++ b/src/visit.rs @@ -315,7 +315,7 @@ impl InnerVisit for TypeDef { visitor.visit_attr_set_use(*attrs); match kind { - TypeKind::QPtr | TypeKind::SpvStringLiteralForExtInst => {} + 
TypeKind::Scalar(_) | TypeKind::QPtr | TypeKind::SpvStringLiteralForExtInst => {} TypeKind::SpvInst { spv_inst: _, type_and_const_inputs } => { for &ty_or_ct in type_and_const_inputs { @@ -336,7 +336,7 @@ impl InnerVisit for ConstDef { visitor.visit_attr_set_use(*attrs); visitor.visit_type_use(*ty); match kind { - ConstKind::Undef | ConstKind::SpvStringLiteralForExtInst(_) => {} + ConstKind::Undef | ConstKind::Scalar(_) | ConstKind::SpvStringLiteralForExtInst(_) => {} &ConstKind::PtrToGlobalVar(gv) => visitor.visit_global_var_use(gv), ConstKind::SpvInst { spv_inst_and_const_inputs } => { From f31e5a22925c9bb5fa03a86ff1d048b9b2474ba2 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Mon, 23 Oct 2023 10:12:56 +0300 Subject: [PATCH 03/22] Add `DataInstKind::Scalar` for pure scalars->scalars ops. --- README.md | 6 +- src/lib.rs | 6 + src/print/mod.rs | 13 +++ src/qptr/analyze.rs | 2 + src/qptr/lift.rs | 2 + src/qptr/lower.rs | 2 +- src/scalar.rs | 271 +++++++++++++++++++++++++++++++++++++++++++ src/spv/canonical.rs | 186 +++++++++++++++++++++++++++-- src/spv/lift.rs | 44 +++---- src/spv/lower.rs | 8 +- src/transform.rs | 4 +- src/visit.rs | 4 +- 12 files changed, 509 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index 515d6ed6..7ca18c9e 100644 --- a/README.md +++ b/README.md @@ -139,10 +139,10 @@ global_var GV0 in spv.StorageClass.Output: s32 func F0() -> spv.OpTypeVoid { loop(v0: s32 <- 1s32, v1: s32 <- 1s32) { - v2 = spv.OpSLessThan(v1, 10s32): bool + v2 = s.lt(v1, 10s32): bool (v3: s32, v4: s32) = if v2 { - v5 = spv.OpIMul(v0, v1): s32 - v6 = spv.OpIAdd(v1, 1s32): s32 + v5 = i.mul(v0, v1): s32 + v6 = i.add(v1, 1s32): s32 (v5, v6) } else { (undef: s32, undef: s32) diff --git a/src/lib.rs b/src/lib.rs index c049394e..ac67da49 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -938,6 +938,12 @@ pub struct DataInstFormDef { #[derive(Clone, PartialEq, Eq, Hash, derive_more::From)] pub enum DataInstKind { + /// Scalar (`bool`, integer, and 
floating-point) pure operations. + /// + /// See also the [`scalar`] module for more documentation and definitions. + #[from] + Scalar(scalar::Op), + // FIXME(eddyb) try to split this into recursive and non-recursive calls, // to avoid needing special handling for recursion where it's impossible. FuncCall(Func), diff --git a/src/print/mod.rs b/src/print/mod.rs index 91637242..cd64c5d2 100644 --- a/src/print/mod.rs +++ b/src/print/mod.rs @@ -3044,6 +3044,19 @@ impl Print for FuncAt<'_, DataInst> { let mut output_type_to_print = *output_type; let def_without_type = match kind { + &DataInstKind::Scalar(op) => { + let name = op.name(); + let (namespace_prefix, name) = name.split_at(name.find('.').unwrap() + 1); + pretty::Fragment::new([ + printer + .demote_style_for_namespace_prefix(printer.declarative_keyword_style()) + .apply(namespace_prefix) + .into(), + printer.declarative_keyword_style().apply(name).into(), + pretty::join_comma_sep("(", inputs.iter().map(|v| v.print(printer)), ")"), + ]) + } + &DataInstKind::FuncCall(func) => pretty::Fragment::new([ printer.declarative_keyword_style().apply("call").into(), " ".into(), diff --git a/src/qptr/analyze.rs b/src/qptr/analyze.rs index daaf1390..51e66351 100644 --- a/src/qptr/analyze.rs +++ b/src/qptr/analyze.rs @@ -906,6 +906,8 @@ impl<'a> InferUsage<'a> { }); }; match &data_inst_form_def.kind { + DataInstKind::Scalar(_) => {} + &DataInstKind::FuncCall(callee) => { match self.infer_usage_in_func(module, callee) { FuncInferUsageState::Complete(callee_results) => { diff --git a/src/qptr/lift.rs b/src/qptr/lift.rs index aa5162fe..f21875c9 100644 --- a/src/qptr/lift.rs +++ b/src/qptr/lift.rs @@ -404,6 +404,8 @@ impl LiftToSpvPtrInstsInFunc<'_> { Ok((addr_space, self.lifter.layout_of(pointee_type)?)) }; let replacement_data_inst_def = match &data_inst_form_def.kind { + DataInstKind::Scalar(_) => return Ok(Transformed::Unchanged), + &DataInstKind::FuncCall(_callee) => { for &v in &data_inst_def.inputs { if 
self.lifter.as_spv_ptr_type(type_of_val(v)).is_some() { diff --git a/src/qptr/lower.rs b/src/qptr/lower.rs index 56a5861e..87fa70a4 100644 --- a/src/qptr/lower.rs +++ b/src/qptr/lower.rs @@ -616,7 +616,7 @@ impl LowerFromSpvPtrInstsInFunc<'_> { match data_inst_form_def.kind { // Known semantics, no need to preserve SPIR-V pointer information. - DataInstKind::FuncCall(_) | DataInstKind::QPtr(_) => return, + DataInstKind::Scalar(_) | DataInstKind::FuncCall(_) | DataInstKind::QPtr(_) => return, DataInstKind::SpvInst(_) | DataInstKind::SpvExtInst { .. } => {} } diff --git a/src/scalar.rs b/src/scalar.rs index e808fbe4..29de3a50 100644 --- a/src/scalar.rs +++ b/src/scalar.rs @@ -196,3 +196,274 @@ impl Const { self.int_as_u128()?.try_into().ok() } } + +/// Pure operations with scalar inputs and outputs. +// +// FIXME(eddyb) these are not some "perfect" grouping, but allow for more +// flexibility in users of this `enum` (and its component `enum`s). +#[derive(Copy, Clone, PartialEq, Eq, Hash, derive_more::From)] +pub enum Op { + BoolUnary(BoolUnOp), + BoolBinary(BoolBinOp), + + IntUnary(IntUnOp), + IntBinary(IntBinOp), + + FloatUnary(FloatUnOp), + FloatBinary(FloatBinOp), +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum BoolUnOp { + Not, +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum BoolBinOp { + Eq, + // FIXME(eddyb) should this be `Xor` instead? + Ne, + Or, + And, +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum IntUnOp { + Neg, + Not, + CountOnes, + + // FIXME(eddyb) ideally `Trunc` should be separated and common. + TruncOrZeroExtend, + TruncOrSignExtend, +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum IntBinOp { + // I×I→I + Add, + Sub, + Mul, + DivU, + DivS, + ModU, + RemS, + ModS, + ShrU, + ShrS, + Shl, + Or, + Xor, + And, + + // I×I→I×I + CarryingAdd, + BorrowingSub, + WideningMulU, + WideningMulS, + + // I×I→B + Eq, + Ne, + // FIXME(eddyb) deduplicate between signed and unsigned. 
+ GtU, + GtS, + GeU, + GeS, + LtU, + LtS, + LeU, + LeS, +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum FloatUnOp { + // F→F + Neg, + + // F→B + IsNan, + IsInf, + + // FIXME(eddyb) these are a complicated mix of signatures. + FromUInt, + FromSInt, + ToUInt, + ToSInt, + Convert, + QuantizeAsF16, +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum FloatBinOp { + // F×F→F + Add, + Sub, + Mul, + Div, + Rem, + Mod, + + // F×F→B + Cmp(FloatCmp), + // FIXME(eddyb) this doesn't properly convey that this is effectively the + // boolean flip of the opposite comparison, e.g. `CmpOrUnord(Ge)` is really + // a fused version of `Not(Cmp(Lt))`, because `x < y` is never `true` for + // unordered `x` and `y` (i.e. `PartialOrd::partial_cmp(x, y) == None`), + // but that maps to `!(x < y)` always being `true` for unordered `x` and `y`, + // and thus `x >= y` is only equivalent for the ordered cases. + CmpOrUnord(FloatCmp), +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum FloatCmp { + Eq, + Ne, + Lt, + Gt, + Le, + Ge, +} + +impl Op { + pub fn output_count(self) -> usize { + match self { + Op::IntBinary(op) => op.output_count(), + _ => 1, + } + } + + pub fn name(self) -> &'static str { + match self { + Op::BoolUnary(op) => op.name(), + Op::BoolBinary(op) => op.name(), + + Op::IntUnary(op) => op.name(), + Op::IntBinary(op) => op.name(), + + Op::FloatUnary(op) => op.name(), + Op::FloatBinary(op) => op.name(), + } + } +} + +impl BoolUnOp { + pub fn name(self) -> &'static str { + match self { + BoolUnOp::Not => "bool.not", + } + } +} + +impl BoolBinOp { + pub fn name(self) -> &'static str { + match self { + BoolBinOp::Eq => "bool.eq", + BoolBinOp::Ne => "bool.ne", + BoolBinOp::Or => "bool.or", + BoolBinOp::And => "bool.and", + } + } +} + +impl IntUnOp { + pub fn name(self) -> &'static str { + match self { + IntUnOp::Neg => "i.neg", + IntUnOp::Not => "i.not", + IntUnOp::CountOnes => "i.count_ones", + + IntUnOp::TruncOrZeroExtend => "u.trunc_or_zext", + 
IntUnOp::TruncOrSignExtend => "s.trunc_or_sext", + } + } +} + +impl IntBinOp { + pub fn output_count(self) -> usize { + // FIXME(eddyb) should these 4 go into a different `enum`? + match self { + IntBinOp::CarryingAdd + | IntBinOp::BorrowingSub + | IntBinOp::WideningMulU + | IntBinOp::WideningMulS => 2, + _ => 1, + } + } + + pub fn name(self) -> &'static str { + match self { + IntBinOp::Add => "i.add", + IntBinOp::Sub => "i.sub", + IntBinOp::Mul => "i.mul", + IntBinOp::DivU => "u.div", + IntBinOp::DivS => "s.div", + IntBinOp::ModU => "u.mod", + IntBinOp::RemS => "s.rem", + IntBinOp::ModS => "s.mod", + IntBinOp::ShrU => "u.shr", + IntBinOp::ShrS => "s.shr", + IntBinOp::Shl => "i.shl", + IntBinOp::Or => "i.or", + IntBinOp::Xor => "i.xor", + IntBinOp::And => "i.and", + IntBinOp::CarryingAdd => "i.carrying_add", + IntBinOp::BorrowingSub => "i.borrowing_sub", + IntBinOp::WideningMulU => "u.widening_mul", + IntBinOp::WideningMulS => "s.widening_mul", + IntBinOp::Eq => "i.eq", + IntBinOp::Ne => "i.ne", + IntBinOp::GtU => "u.gt", + IntBinOp::GtS => "s.gt", + IntBinOp::GeU => "u.ge", + IntBinOp::GeS => "s.ge", + IntBinOp::LtU => "u.lt", + IntBinOp::LtS => "s.lt", + IntBinOp::LeU => "u.le", + IntBinOp::LeS => "s.le", + } + } +} + +impl FloatUnOp { + pub fn name(self) -> &'static str { + match self { + FloatUnOp::Neg => "f.neg", + + FloatUnOp::IsNan => "f.is_nan", + FloatUnOp::IsInf => "f.is_inf", + + FloatUnOp::FromUInt => "f.from_uint", + FloatUnOp::FromSInt => "f.from_sint", + FloatUnOp::ToUInt => "f.to_uint", + FloatUnOp::ToSInt => "f.to_sint", + FloatUnOp::Convert => "f.convert", + FloatUnOp::QuantizeAsF16 => "f.quantize_as_f16", + } + } +} + +impl FloatBinOp { + pub fn name(self) -> &'static str { + match self { + FloatBinOp::Add => "f.add", + FloatBinOp::Sub => "f.sub", + FloatBinOp::Mul => "f.mul", + FloatBinOp::Div => "f.div", + FloatBinOp::Rem => "f.rem", + FloatBinOp::Mod => "f.mod", + FloatBinOp::Cmp(FloatCmp::Eq) => "f.eq", + FloatBinOp::Cmp(FloatCmp::Ne) => 
"f.ne", + FloatBinOp::Cmp(FloatCmp::Lt) => "f.lt", + FloatBinOp::Cmp(FloatCmp::Gt) => "f.gt", + FloatBinOp::Cmp(FloatCmp::Le) => "f.le", + FloatBinOp::Cmp(FloatCmp::Ge) => "f.ge", + FloatBinOp::CmpOrUnord(FloatCmp::Eq) => "f.eq_or_unord", + FloatBinOp::CmpOrUnord(FloatCmp::Ne) => "f.ne_or_unord", + FloatBinOp::CmpOrUnord(FloatCmp::Lt) => "f.lt_or_unord", + FloatBinOp::CmpOrUnord(FloatCmp::Gt) => "f.gt_or_unord", + FloatBinOp::CmpOrUnord(FloatCmp::Le) => "f.le_or_unord", + FloatBinOp::CmpOrUnord(FloatCmp::Ge) => "f.ge_or_unord", + } + } +} diff --git a/src/spv/canonical.rs b/src/spv/canonical.rs index b538ee3a..c170e047 100644 --- a/src/spv/canonical.rs +++ b/src/spv/canonical.rs @@ -8,15 +8,21 @@ // FIXME(eddyb) should interning attempts check/apply these canonicalizations? use crate::spv::{self, spec}; -use crate::{scalar, ConstKind, Context, Type, TypeKind}; +use crate::{scalar, ConstKind, Context, DataInstKind, Type, TypeKind}; use lazy_static::lazy_static; // FIXME(eddyb) these ones could maybe make use of build script generation. macro_rules! def_mappable_ops { - ($($op:ident),+ $(,)?) => { + ( + type { $($ty_op:ident),+ $(,)? } + const { $($ct_op:ident),+ $(,)? } + $($enum_path:path { $($variant_op:ident <=> $variant:ident$(($($variant_args:tt)*))?),+ $(,)? })* + ) => { #[allow(non_snake_case)] struct MappableOps { - $($op: spec::Opcode,)+ + $($ty_op: spec::Opcode,)+ + $($ct_op: spec::Opcode,)+ + $($($variant_op: spec::Opcode,)+)* } impl MappableOps { #[inline(always)] @@ -26,24 +32,136 @@ macro_rules! 
def_mappable_ops { static ref MAPPABLE_OPS: MappableOps = { let spv_spec = spec::Spec::get(); MappableOps { - $($op: spv_spec.instructions.lookup(stringify!($op)).unwrap(),)+ + $($ty_op: spv_spec.instructions.lookup(stringify!($ty_op)).unwrap(),)+ + $($ct_op: spv_spec.instructions.lookup(stringify!($ct_op)).unwrap(),)+ + $($($variant_op: spv_spec.instructions.lookup(stringify!($variant_op)).unwrap(),)+)* } }; } &MAPPABLE_OPS } } + // NOTE(eddyb) these should stay private, hence not implementing `TryFrom`. + $(impl $enum_path { + fn try_from_opcode(opcode: spec::Opcode) -> Option<Self> { + let mo = MappableOps::get(); + $(if opcode == mo.$variant_op { + return Some(Self::$variant$(($($variant_args)*))?); + })+ + None + } + fn to_opcode(self) -> spec::Opcode { + let mo = MappableOps::get(); + match self { + $(Self::$variant$(($($variant_args)*))? => mo.$variant_op,)+ + } + } + })* }; } def_mappable_ops! { - OpTypeBool, - OpTypeInt, - OpTypeFloat, - - OpUndef, - OpConstantFalse, - OpConstantTrue, - OpConstant, + // FIXME(eddyb) these categories don't actually do anything right now + type { + OpTypeBool, + OpTypeInt, + OpTypeFloat, + } + const { + OpUndef, + OpConstantFalse, + OpConstantTrue, + OpConstant, + } + scalar::BoolUnOp { + OpLogicalNot <=> Not, + } + scalar::BoolBinOp { + OpLogicalEqual <=> Eq, + OpLogicalNotEqual <=> Ne, + OpLogicalOr <=> Or, + OpLogicalAnd <=> And, + } + scalar::IntUnOp { + OpSNegate <=> Neg, + OpNot <=> Not, + OpBitCount <=> CountOnes, + + OpUConvert <=> TruncOrZeroExtend, + OpSConvert <=> TruncOrSignExtend, + } + scalar::IntBinOp { + // I×I→I + OpIAdd <=> Add, + OpISub <=> Sub, + OpIMul <=> Mul, + OpUDiv <=> DivU, + OpSDiv <=> DivS, + OpUMod <=> ModU, + OpSRem <=> RemS, + OpSMod <=> ModS, + OpShiftRightLogical <=> ShrU, + OpShiftRightArithmetic <=> ShrS, + OpShiftLeftLogical <=> Shl, + OpBitwiseOr <=> Or, + OpBitwiseXor <=> Xor, + OpBitwiseAnd <=> And, + + // I×I→I×I + OpIAddCarry <=> CarryingAdd, + OpISubBorrow <=> BorrowingSub, +
OpUMulExtended <=> WideningMulU, + OpSMulExtended <=> WideningMulS, + + // I×I→B + OpIEqual <=> Eq, + OpINotEqual <=> Ne, + OpUGreaterThan <=> GtU, + OpSGreaterThan <=> GtS, + OpUGreaterThanEqual <=> GeU, + OpSGreaterThanEqual <=> GeS, + OpULessThan <=> LtU, + OpSLessThan <=> LtS, + OpULessThanEqual <=> LeU, + OpSLessThanEqual <=> LeS, + } + scalar::FloatUnOp { + // F→F + OpFNegate <=> Neg, + + // F→B + OpIsNan <=> IsNan, + OpIsInf <=> IsInf, + + OpConvertUToF <=> FromUInt, + OpConvertSToF <=> FromSInt, + OpConvertFToU <=> ToUInt, + OpConvertFToS <=> ToSInt, + OpFConvert <=> Convert, + OpQuantizeToF16 <=> QuantizeAsF16, + } + scalar::FloatBinOp { + // F×F→F + OpFAdd <=> Add, + OpFSub <=> Sub, + OpFMul <=> Mul, + OpFDiv <=> Div, + OpFRem <=> Rem, + OpFMod <=> Mod, + + // F×F→B + OpFOrdEqual <=> Cmp(scalar::FloatCmp::Eq), + OpFOrdNotEqual <=> Cmp(scalar::FloatCmp::Ne), + OpFOrdLessThan <=> Cmp(scalar::FloatCmp::Lt), + OpFOrdGreaterThan <=> Cmp(scalar::FloatCmp::Gt), + OpFOrdLessThanEqual <=> Cmp(scalar::FloatCmp::Le), + OpFOrdGreaterThanEqual <=> Cmp(scalar::FloatCmp::Ge), + OpFUnordEqual <=> CmpOrUnord(scalar::FloatCmp::Eq), + OpFUnordNotEqual <=> CmpOrUnord(scalar::FloatCmp::Ne), + OpFUnordLessThan <=> CmpOrUnord(scalar::FloatCmp::Lt), + OpFUnordGreaterThan <=> CmpOrUnord(scalar::FloatCmp::Gt), + OpFUnordLessThanEqual <=> CmpOrUnord(scalar::FloatCmp::Le), + OpFUnordGreaterThanEqual <=> CmpOrUnord(scalar::FloatCmp::Ge), + } } impl scalar::Const { @@ -223,4 +341,48 @@ impl spv::Inst { | ConstKind::SpvStringLiteralForExtInst(_) => None, } } + + pub(super) fn as_canonical_data_inst_kind( + &self, + cx: &Context, + output_types: &[Type], + ) -> Option<DataInstKind> { + let Self { opcode, imms } = self; + let (&opcode, imms) = (opcode, &imms[..]); + + let scalar_op = (scalar::BoolUnOp::try_from_opcode(opcode).map(scalar::Op::from)) + .or_else(|| scalar::BoolBinOp::try_from_opcode(opcode).map(scalar::Op::from)) + .or_else(||
scalar::IntUnOp::try_from_opcode(opcode).map(scalar::Op::from)) + .or_else(|| scalar::IntBinOp::try_from_opcode(opcode).map(scalar::Op::from)) + .or_else(|| scalar::FloatUnOp::try_from_opcode(opcode).map(scalar::Op::from)) + .or_else(|| scalar::FloatBinOp::try_from_opcode(opcode).map(scalar::Op::from)); + if let Some(op) = scalar_op { + assert_eq!(imms.len(), 0); + + // FIXME(eddyb) support vector versions of these ops as well. + if output_types.len() == op.output_count() + && output_types.iter().all(|ty| ty.as_scalar(cx).is_some()) + { + Some(op.into()) + } else { + None + } + } else { + None + } + } + + pub(super) fn from_canonical_data_inst_kind(data_inst_kind: &DataInstKind) -> Option<Self> { + match data_inst_kind { + &DataInstKind::Scalar(op) => Some(match op { + scalar::Op::BoolUnary(op) => op.to_opcode().into(), + scalar::Op::BoolBinary(op) => op.to_opcode().into(), + scalar::Op::IntUnary(op) => op.to_opcode().into(), + scalar::Op::IntBinary(op) => op.to_opcode().into(), + scalar::Op::FloatUnary(op) => op.to_opcode().into(), + scalar::Op::FloatBinary(op) => op.to_opcode().into(), + }), + _ => None, + } + } } diff --git a/src/spv/lift.rs b/src/spv/lift.rs index 0cb139d2..d84bf1cb 100644 --- a/src/spv/lift.rs +++ b/src/spv/lift.rs @@ -220,7 +220,6 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { } fn visit_data_inst_form_def(&mut self, data_inst_form_def: &DataInstFormDef) { - #[allow(clippy::match_same_arms)] match data_inst_form_def.kind { // FIXME(eddyb) this should be a proper `Result`-based error instead, // and/or `spv::lift` should mutate the module for legalization. @@ -228,9 +227,7 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { unreachable!("`DataInstKind::QPtr` should be legalized away before lifting"); } - DataInstKind::FuncCall(_) => {} - - DataInstKind::SpvInst(_) => {} + DataInstKind::Scalar(_) | DataInstKind::FuncCall(_) | DataInstKind::SpvInst(_) => {} DataInstKind::SpvExtInst { ext_set, ..
} => { self.ext_inst_imports.insert(&self.cx[ext_set]); } @@ -1256,23 +1253,30 @@ impl LazyInst<'_, '_> { }, Self::DataInst { parent_func, result_id: _, data_inst_def } => { let DataInstFormDef { kind, output_type } = &cx[data_inst_def.form]; - let (inst, extra_initial_id_operand) = match kind { - // Disallowed while visiting. - DataInstKind::QPtr(_) => unreachable!(), + let (inst, extra_initial_id_operand) = + match spv::Inst::from_canonical_data_inst_kind(kind).ok_or(kind) { + Ok(spv_inst) => (spv_inst, None), - &DataInstKind::FuncCall(callee) => { - (wk.OpFunctionCall.into(), Some(ids.funcs[&callee].func_id)) - } - DataInstKind::SpvInst(inst) => (inst.clone(), None), - &DataInstKind::SpvExtInst { ext_set, inst } => ( - spv::Inst { - opcode: wk.OpExtInst, - imms: iter::once(spv::Imm::Short(wk.LiteralExtInstInteger, inst)) - .collect(), - }, - Some(ids.ext_inst_imports[&cx[ext_set]]), - ), - }; + Err(DataInstKind::Scalar(_)) => { + unreachable!("should've been handled as canonical") + } + + // Disallowed while visiting. + Err(DataInstKind::QPtr(_)) => unreachable!(), + + Err(&DataInstKind::FuncCall(callee)) => { + (wk.OpFunctionCall.into(), Some(ids.funcs[&callee].func_id)) + } + Err(DataInstKind::SpvInst(inst)) => (inst.clone(), None), + Err(&DataInstKind::SpvExtInst { ext_set, inst }) => ( + spv::Inst { + opcode: wk.OpExtInst, + imms: iter::once(spv::Imm::Short(wk.LiteralExtInstInteger, inst)) + .collect(), + }, + Some(ids.ext_inst_imports[&cx[ext_set]]), + ), + }; spv::InstWithIds { without_ids: inst, result_type_id: output_type.map(|ty| ids.globals[&Global::Type(ty)]), diff --git a/src/spv/lower.rs b/src/spv/lower.rs index 0c6a6984..50169b27 100644 --- a/src/spv/lower.rs +++ b/src/spv/lower.rs @@ -1292,7 +1292,13 @@ impl Module { // some "structured regions" replacement for the CFG. 
} else { let mut ids = &ids[..]; - let kind = if opcode == wk.OpFunctionCall { + let kind = if let Some(kind) = raw_inst.without_ids.as_canonical_data_inst_kind( + &cx, + result_type.map(|ty| [ty]).as_ref().map_or(&[][..], |tys| &tys[..]), + ) { + // FIXME(eddyb) sanity-check the number/types of inputs. + kind + } else if opcode == wk.OpFunctionCall { assert!(imms.is_empty()); let callee_id = ids[0]; let maybe_callee = id_defs diff --git a/src/transform.rs b/src/transform.rs index 83c626af..6cb697a2 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -721,7 +721,9 @@ impl InnerTransform for DataInstFormDef { | QPtrOp::Load | QPtrOp::Store => Transformed::Unchanged, }, - DataInstKind::SpvInst(_) | DataInstKind::SpvExtInst { .. } => Transformed::Unchanged, + DataInstKind::Scalar(_) + | DataInstKind::SpvInst(_) + | DataInstKind::SpvExtInst { .. } => Transformed::Unchanged, }, // FIXME(eddyb) this should be replaced with an impl of `InnerTransform` // for `Option` or some other helper, to avoid "manual transpose". diff --git a/src/visit.rs b/src/visit.rs index 665b4a6d..1b5d9718 100644 --- a/src/visit.rs +++ b/src/visit.rs @@ -535,7 +535,9 @@ impl InnerVisit for DataInstFormDef { | QPtrOp::Load | QPtrOp::Store => {} }, - DataInstKind::SpvInst(_) | DataInstKind::SpvExtInst { .. } => {} + DataInstKind::Scalar(_) + | DataInstKind::SpvInst(_) + | DataInstKind::SpvExtInst { .. } => {} } if let Some(ty) = *output_type { visitor.visit_type_use(ty); From ad2278409cf2fd65511dfbd923982fdb14071875 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Tue, 7 Nov 2023 19:56:05 +0200 Subject: [PATCH 04/22] Add `SelectionKind::Switch`, using one `scalar::Const` for each case's constant. 
--- src/lib.rs | 13 ++- src/print/mod.rs | 43 +++++---- src/spv/canonical.rs | 9 +- src/spv/lift.rs | 36 ++++---- src/spv/lower.rs | 202 +++++++++++++++++++++++++++++-------------- src/spv/read.rs | 11 +-- src/transform.rs | 4 +- src/visit.rs | 4 +- 8 files changed, 213 insertions(+), 109 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index ac67da49..e767790f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -895,13 +895,24 @@ pub enum ControlNodeKind { }, } +// FIXME(eddyb) consider interning this, perhaps in a similar vein to `DataInstForm`. #[derive(Clone)] pub enum SelectionKind { /// Two-case selection based on boolean condition, i.e. `if`-`else`, with /// the two cases being "then" and "else" (in that order). BoolCond, - SpvInst(spv::Inst), + /// `N+1`-case selection based on comparing an integer scrutinee against + /// `N` constants, i.e. `switch`, with the last case being the "default" + /// (making it the only case without a matching entry in `case_consts`). + Switch { + // FIXME(eddyb) avoid some of the `scalar::Const` overhead here, as there + // is only a single type and we shouldn't need to store more bits per case, + // than the actual width of the integer type. + // FIXME(eddyb) consider storing this more like a sorted compressed keyset, + // as there can be no duplicates, and in many cases it may be contiguous. + case_consts: Vec<scalar::Const>, + }, } /// Entity handle for a [`DataInstDef`](crate::DataInstDef) (an SSA instruction).
diff --git a/src/print/mod.rs b/src/print/mod.rs index cd64c5d2..59719995 100644 --- a/src/print/mod.rs +++ b/src/print/mod.rs @@ -2953,7 +2953,7 @@ impl Print for FuncAt<'_, ControlNode> { ( pretty::join_comma_sep( "(", - input_decls_and_uses.clone().zip(initial_inputs).map( + input_decls_and_uses.clone().zip_eq(initial_inputs).map( |((input_decl, input_use), initial)| { pretty::Fragment::new([ input_decl.print(printer).insert_name_before_def( @@ -3483,7 +3483,7 @@ impl SelectionKind { mut cases: impl ExactSizeIterator, ) -> pretty::Fragment { let kw = |kw| kw_style.apply(kw).into(); - match *self { + match self { SelectionKind::BoolCond => { assert_eq!(cases.len(), 2); let [then_case, else_case] = [cases.next().unwrap(), cases.next().unwrap()]; @@ -3500,27 +3500,36 @@ impl SelectionKind { "}".into(), ]) } - SelectionKind::SpvInst(spv::Inst { opcode, ref imms }) => { - let header = printer.pretty_spv_inst( - kw_style, - opcode, - imms, - [Some(scrutinee.print(printer))] - .into_iter() - .chain((0..cases.len()).map(|_| None)), - ); + SelectionKind::Switch { case_consts } => { + assert_eq!(cases.len(), case_consts.len() + 1); + + let case_patterns = case_consts + .iter() + .map(|&ct| { + let int_to_string = (ct.int_as_u128().map(|x| x.to_string())) + .or_else(|| ct.int_as_i128().map(|x| x.to_string())); + match int_to_string { + Some(v) => printer.numeric_literal_style().apply(v).into(), + None => { + let ct: Const = printer.cx.intern(ct); + ct.print(printer) + } + } + }) + .chain(["_".into()]); pretty::Fragment::new([ - header, + kw("switch"), + " ".into(), + scrutinee.print(printer), " {".into(), pretty::Node::IndentedBlock( - cases - .map(|case| { + case_patterns + .zip_eq(cases) + .map(|(case_pattern, case)| { pretty::Fragment::new([ pretty::Node::ForceLineSeparation.into(), - // FIXME(eddyb) this should pull information out - // of the instruction to be more precise. 
- kw("case"), + case_pattern, " => {".into(), pretty::Node::IndentedBlock(vec![case]).into(), "}".into(), diff --git a/src/spv/canonical.rs b/src/spv/canonical.rs index c170e047..e99d21dd 100644 --- a/src/spv/canonical.rs +++ b/src/spv/canonical.rs @@ -165,7 +165,11 @@ def_mappable_ops! { } impl scalar::Const { - fn try_decode_from_spv_imms(ty: scalar::Type, imms: &[spv::Imm]) -> Option { + // HACK(eddyb) this is not private so `spv::lower` can use it for `OpSwitch`. + pub(super) fn try_decode_from_spv_imms( + ty: scalar::Type, + imms: &[spv::Imm], + ) -> Option { // FIXME(eddyb) don't hardcode the 128-bit limitation, // but query `scalar::Const` somehow instead. if ty.bit_width() > 128 { @@ -198,7 +202,8 @@ impl scalar::Const { } } - fn encode_as_spv_imms(&self) -> impl Iterator { + // HACK(eddyb) this is not private so `spv::lift` can use it for `OpSwitch`. + pub(super) fn encode_as_spv_imms(&self) -> impl Iterator { let wk = &spec::Spec::get().well_known; let ty = self.ty(); diff --git a/src/spv/lift.rs b/src/spv/lift.rs index d84bf1cb..100a9f5c 100644 --- a/src/spv/lift.rs +++ b/src/spv/lift.rs @@ -1309,6 +1309,14 @@ impl LazyInst<'_, '_> { ids: [merge_label_id, continue_label_id].into_iter().collect(), }, Self::Terminator { parent_func, terminator } => { + let mut ids: SmallVec<[_; 4]> = terminator + .inputs + .iter() + .map(|&v| value_to_id(parent_func, v)) + .chain(terminator.targets.iter().map(|&target| parent_func.label_ids[&target])) + .collect(); + + // FIXME(eddyb) move some of this to `spv::canonical`. 
let inst = match &*terminator.kind { cfg::ControlInstKind::Unreachable => wk.OpUnreachable.into(), cfg::ControlInstKind::Return => { @@ -1327,23 +1335,21 @@ impl LazyInst<'_, '_> { cfg::ControlInstKind::SelectBranch(SelectionKind::BoolCond) => { wk.OpBranchConditional.into() } - cfg::ControlInstKind::SelectBranch(SelectionKind::SpvInst(inst)) => { - inst.clone() + cfg::ControlInstKind::SelectBranch(SelectionKind::Switch { case_consts }) => { + // HACK(eddyb) move the default case from last back to first. + let default_target = ids.pop().unwrap(); + ids.insert(1, default_target); + + spv::Inst { + opcode: wk.OpSwitch, + imms: case_consts + .iter() + .flat_map(|ct| ct.encode_as_spv_imms()) + .collect(), + } } }; - spv::InstWithIds { - without_ids: inst, - result_type_id: None, - result_id: None, - ids: terminator - .inputs - .iter() - .map(|&v| value_to_id(parent_func, v)) - .chain( - terminator.targets.iter().map(|&target| parent_func.label_ids[&target]), - ) - .collect(), - } + spv::InstWithIds { without_ids: inst, result_type_id: None, result_id: None, ids } } Self::OpFunctionEnd => spv::InstWithIds { without_ids: wk.OpFunctionEnd.into(), diff --git a/src/spv/lower.rs b/src/spv/lower.rs index 50169b27..2c9b7cab 100644 --- a/src/spv/lower.rs +++ b/src/spv/lower.rs @@ -3,11 +3,11 @@ use crate::spv::{self, spec}; // FIXME(eddyb) import more to avoid `crate::` everywhere. 
use crate::{ - cfg, print, AddrSpace, Attr, AttrSet, Const, ConstDef, ConstKind, Context, ControlNodeDef, - ControlNodeKind, ControlRegion, ControlRegionDef, ControlRegionInputDecl, DataInstDef, - DataInstFormDef, DataInstKind, DeclDef, Diag, EntityDefs, EntityList, ExportKey, Exportee, - Func, FuncDecl, FuncDefBody, FuncParam, FxIndexMap, GlobalVarDecl, GlobalVarDefBody, Import, - InternedStr, Module, SelectionKind, Type, TypeDef, TypeKind, TypeOrConst, Value, + cfg, print, scalar, AddrSpace, Attr, AttrSet, Const, ConstDef, ConstKind, Context, + ControlNodeDef, ControlNodeKind, ControlRegion, ControlRegionDef, ControlRegionInputDecl, + DataInstDef, DataInstFormDef, DataInstKind, DeclDef, Diag, EntityDefs, EntityList, ExportKey, + Exportee, Func, FuncDecl, FuncDefBody, FuncParam, FxIndexMap, GlobalVarDecl, GlobalVarDefBody, + Import, InternedStr, Module, SelectionKind, Type, TypeDef, TypeKind, TypeOrConst, Value, }; use rustc_hash::FxHashMap; use smallvec::SmallVec; @@ -85,6 +85,20 @@ fn invalid(reason: &str) -> io::Error { io::Error::new(io::ErrorKind::InvalidData, format!("malformed SPIR-V ({reason})")) } +fn invalid_factory_for_spv_inst( + inst: &spv::Inst, + result_id: Option, + ids: &[spv::Id], +) -> impl Fn(&str) -> io::Error { + let opcode = inst.opcode; + let first_id_operand = ids.first().copied(); + move |msg: &str| { + let result_prefix = result_id.map(|id| format!("%{id} = ")).unwrap_or_default(); + let operand_suffix = first_id_operand.map(|id| format!(" %{id} ...")).unwrap_or_default(); + invalid(&format!("in {result_prefix}{}{operand_suffix}: {msg}", opcode.name())) + } +} + // FIXME(eddyb) provide more information about any normalization that happened: // * stats about deduplication that occured through interning // * sets of unused global vars and functions (and types+consts only they use) @@ -195,7 +209,7 @@ impl Module { while let Some(mut inst) = spv_insts.next().transpose()? 
{ let opcode = inst.opcode; - let invalid = |msg: &str| invalid(&format!("in {}: {}", opcode.name(), msg)); + let invalid = invalid_factory_for_spv_inst(&inst, inst.result_id, &inst.ids); // Handle line debuginfo early, as it doesn't have its own section, // but rather can go almost anywhere among globals and functions. @@ -861,7 +875,7 @@ impl Module { #[derive(Copy, Clone)] enum LocalIdDef { - Value(Value), + Value(Type, Value), BlockLabel(ControlRegion), } @@ -889,6 +903,7 @@ impl Module { let IntraFuncInst { without_ids: spv::Inst { opcode, ref imms }, result_id, + result_type, .. } = *raw_inst; @@ -903,10 +918,10 @@ impl Module { DeclDef::Present(def) => def.body, }; - LocalIdDef::Value(Value::ControlRegionInput { - region: body, - input_idx: idx, - }) + LocalIdDef::Value( + result_type.unwrap(), + Value::ControlRegionInput { region: body, input_idx: idx }, + ) } else { let is_entry_block = !has_blocks; has_blocks = true; @@ -957,10 +972,13 @@ impl Module { .push(value_id); } - LocalIdDef::Value(Value::ControlRegionInput { - region: current_block, - input_idx: phi_idx, - }) + LocalIdDef::Value( + result_type.unwrap(), + Value::ControlRegionInput { + region: current_block, + input_idx: phi_idx, + }, + ) } else { // HACK(eddyb) can't get a `DataInst` without // defining it (as a dummy) first. @@ -974,7 +992,7 @@ impl Module { } .into(), ); - LocalIdDef::Value(Value::DataInstOutput(inst)) + LocalIdDef::Value(result_type.unwrap(), Value::DataInstOutput(inst)) } }; local_id_defs.insert(id, local_id_def); @@ -1023,50 +1041,52 @@ impl Module { ref ids, } = *raw_inst; - let invalid = |msg: &str| invalid(&format!("in {}: {}", opcode.name(), msg)); + let invalid = invalid_factory_for_spv_inst(&raw_inst.without_ids, result_id, ids); // FIXME(eddyb) find a more compact name and/or make this a method. // FIXME(eddyb) this returns `LocalIdDef` even for global values. 
- let lookup_global_or_local_id_for_data_or_control_inst_input = - |id| match id_defs.get(&id) { - Some(&IdDef::Const(ct)) => Ok(LocalIdDef::Value(Value::Const(ct))), - Some(id_def @ IdDef::Type(_)) => Err(invalid(&format!( - "unsupported use of {} as an operand for \ + let lookup_global_or_local_id_for_data_or_control_inst_input = |id| match id_defs + .get(&id) + { + Some(&IdDef::Const(ct)) => Ok(LocalIdDef::Value(cx[ct].ty, Value::Const(ct))), + Some(id_def @ IdDef::Type(_)) => Err(invalid(&format!( + "unsupported use of {} as an operand for \ an instruction in a function", - id_def.descr(&cx), - ))), - Some(id_def @ IdDef::Func(_)) => Err(invalid(&format!( - "unsupported use of {} outside `OpFunctionCall`", - id_def.descr(&cx), - ))), - Some(id_def @ IdDef::SpvDebugString(s)) => { - if opcode == wk.OpExtInst { - // HACK(eddyb) intern `OpString`s as `Const`s on - // the fly, as it's a less likely usage than the - // `OpLine` one. - let ct = cx.intern(ConstDef { - attrs: AttrSet::default(), - ty: cx.intern(TypeKind::SpvStringLiteralForExtInst), - kind: ConstKind::SpvStringLiteralForExtInst(*s), - }); - Ok(LocalIdDef::Value(Value::Const(ct))) - } else { - Err(invalid(&format!( - "unsupported use of {} outside `OpSource`, \ + id_def.descr(&cx), + ))), + Some(id_def @ IdDef::Func(_)) => Err(invalid(&format!( + "unsupported use of {} outside `OpFunctionCall`", + id_def.descr(&cx), + ))), + Some(id_def @ IdDef::SpvDebugString(s)) => { + if opcode == wk.OpExtInst { + // HACK(eddyb) intern `OpString`s as `Const`s on + // the fly, as it's a less likely usage than the + // `OpLine` one. 
+ let ty = cx.intern(TypeKind::SpvStringLiteralForExtInst); + let ct = cx.intern(ConstDef { + attrs: AttrSet::default(), + ty, + kind: ConstKind::SpvStringLiteralForExtInst(*s), + }); + Ok(LocalIdDef::Value(ty, Value::Const(ct))) + } else { + Err(invalid(&format!( + "unsupported use of {} outside `OpSource`, \ `OpLine`, or `OpExtInst`", - id_def.descr(&cx), - ))) - } + id_def.descr(&cx), + ))) } - Some(id_def @ IdDef::SpvExtInstImport(_)) => Err(invalid(&format!( - "unsupported use of {} outside `OpExtInst`", - id_def.descr(&cx), - ))), - None => local_id_defs - .get(&id) - .copied() - .ok_or_else(|| invalid(&format!("undefined ID %{id}",))), - }; + } + Some(id_def @ IdDef::SpvExtInstImport(_)) => Err(invalid(&format!( + "unsupported use of {} outside `OpExtInst`", + id_def.descr(&cx), + ))), + None => local_id_defs + .get(&id) + .copied() + .ok_or_else(|| invalid(&format!("undefined ID %{id}",))), + }; if opcode == wk.OpFunctionParameter { if current_block_control_region_and_details.is_some() { @@ -1104,7 +1124,7 @@ impl Module { // to be able to have an entry in `local_id_defs`. let control_region = match local_id_defs[&result_id.unwrap()] { LocalIdDef::BlockLabel(control_region) => control_region, - LocalIdDef::Value(_) => unreachable!(), + LocalIdDef::Value(..) => unreachable!(), }; let current_block_details = &block_details[&control_region]; assert_eq!(current_block_details.label_id, result_id.unwrap()); @@ -1140,7 +1160,7 @@ impl Module { }; let phi_value_id_to_value = |phi_key: &PhiKey, id| { match lookup_global_or_local_id_for_data_or_control_inst_input(id)? { - LocalIdDef::Value(v) => Ok(v), + LocalIdDef::Value(_, v) => Ok(v), LocalIdDef::BlockLabel { .. } => Err(invalid(&format!( "unsupported use of block label as the value for {}", descr_phi_case(phi_key) @@ -1191,10 +1211,11 @@ impl Module { // Split the operands into value inputs (e.g. a branch's // condition or an `OpSwitch`'s selector) and target blocks. 
let mut inputs = SmallVec::new(); + let mut input_types = SmallVec::<[_; 2]>::new(); let mut targets = SmallVec::new(); for &id in ids { match lookup_global_or_local_id_for_data_or_control_inst_input(id)? { - LocalIdDef::Value(v) => { + LocalIdDef::Value(ty, v) => { if !targets.is_empty() { return Err(invalid( "out of order: value operand \ @@ -1202,6 +1223,7 @@ impl Module { )); } inputs.push(v); + input_types.push(ty); } LocalIdDef::BlockLabel(target) => { record_cfg_edge(target)?; @@ -1210,6 +1232,7 @@ impl Module { } } + // FIXME(eddyb) move some of this to `spv::canonical`. let kind = if opcode == wk.OpUnreachable { assert!(targets.is_empty() && inputs.is_empty()); cfg::ControlInstKind::Unreachable @@ -1227,9 +1250,62 @@ impl Module { assert_eq!((targets.len(), inputs.len()), (2, 1)); cfg::ControlInstKind::SelectBranch(SelectionKind::BoolCond) } else if opcode == wk.OpSwitch { - cfg::ControlInstKind::SelectBranch(SelectionKind::SpvInst( - raw_inst.without_ids.clone(), - )) + assert_eq!(inputs.len(), 1); + + // HACK(eddyb) `spv::read` has to "redundantly" validate + // that such a type is `OpTypeInt`/`OpTypeFloat`, but + // there is still a limitation when it comes to `scalar::Const`. + // FIXME(eddyb) don't hardcode the 128-bit limitation, + // but query `scalar::Const` somehow instead. + let scrutinee_type = input_types[0]; + let scrutinee_type = scrutinee_type + .as_scalar(&cx) + .filter(|ty| { + matches!(ty, scalar::Type::UInt(_) | scalar::Type::SInt(_)) + && ty.bit_width() <= 128 + }) + .ok_or_else(|| { + invalid( + &print::Plan::for_root( + &cx, + &Diag::err([ + "unsupported `OpSwitch` scrutinee type `".into(), + scrutinee_type.into(), + "`".into(), + ]) + .message, + ) + .pretty_print() + .to_string(), + ) + })?; + + // FIXME(eddyb) move some of this to `spv::canonical`. + let imm_words_per_case = + usize::try_from(scrutinee_type.bit_width().div_ceil(32)).unwrap(); + + // NOTE(eddyb) these sanity-checks are redundant with `spv::read`. 
+ assert_eq!(imms.len() % imm_words_per_case, 0); + assert_eq!(targets.len(), 1 + imms.len() / imm_words_per_case); + + let case_consts = imms + .chunks(imm_words_per_case) + .map(|case_imms| { + scalar::Const::try_decode_from_spv_imms(scrutinee_type, case_imms) + .ok_or_else(|| { + invalid(&format!( + "invalid {}-bit `OpSwitch` case constant", + scrutinee_type.bit_width() + )) + }) + }) + .collect::>()?; + + // HACK(eddyb) move the default case from first to last. + let default_target = targets.remove(0); + targets.push(default_target); + + cfg::ControlInstKind::SelectBranch(SelectionKind::Switch { case_consts }) } else { return Err(invalid("unsupported control-flow instruction")); }; @@ -1274,7 +1350,7 @@ impl Module { let loop_merge_target = match lookup_global_or_local_id_for_data_or_control_inst_input(ids[0])? { - LocalIdDef::Value(_) => return Err(invalid("expected label ID")), + LocalIdDef::Value(..) => return Err(invalid("expected label ID")), LocalIdDef::BlockLabel(target) => target, }; @@ -1373,7 +1449,7 @@ impl Module { .map(|&id| { match lookup_global_or_local_id_for_data_or_control_inst_input(id)? { - LocalIdDef::Value(v) => Ok(v), + LocalIdDef::Value(_, v) => Ok(v), LocalIdDef::BlockLabel { .. } => Err(invalid( "unsupported use of block label as a value, \ in non-terminator instruction", @@ -1384,7 +1460,7 @@ impl Module { }; let inst = match result_id { Some(id) => match local_id_defs[&id] { - LocalIdDef::Value(Value::DataInstOutput(inst)) => { + LocalIdDef::Value(_, Value::DataInstOutput(inst)) => { // A dummy was defined earlier, to be able to // have an entry in `local_id_defs`. 
func_def_body.data_insts[inst] = data_inst_def.into(); diff --git a/src/spv/read.rs b/src/spv/read.rs index 5acf2930..c532b804 100644 --- a/src/spv/read.rs +++ b/src/spv/read.rs @@ -173,11 +173,8 @@ impl InstParser<'_> { .and_then(|id| self.known_ids.get(&id)) .ok_or(Error::MissingContextSensitiveLiteralType)?; - let extra_word_count = match *contextual_type { - KnownIdDef::TypeIntOrFloat(width) => { - // HACK(eddyb) `(width + 31) / 32 - 1` but without overflow. - (width.get() - 1) / 32 - } + let word_count = match *contextual_type { + KnownIdDef::TypeIntOrFloat(width) => width.get().div_ceil(32), KnownIdDef::Uncategorized { opcode, .. } => { return Err(Error::UnsupportedContextSensitiveLiteralType { type_opcode: opcode, @@ -185,11 +182,11 @@ impl InstParser<'_> { } }; - if extra_word_count == 0 { + if word_count == 1 { self.inst.imms.push(spv::Imm::Short(kind, word)); } else { self.inst.imms.push(spv::Imm::LongStart(kind, word)); - for _ in 0..extra_word_count { + for _ in 1..word_count { let word = self.words.next().ok_or(Error::NotEnoughWords)?; self.inst.imms.push(spv::Imm::LongCont(kind, word)); } diff --git a/src/transform.rs b/src/transform.rs index 6cb697a2..eef7932a 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -640,7 +640,7 @@ impl InnerInPlaceTransform for FuncAtMut<'_, ControlNode> { } } ControlNodeKind::Select { - kind: SelectionKind::BoolCond | SelectionKind::SpvInst(_), + kind: SelectionKind::BoolCond | SelectionKind::Switch { case_consts: _ }, scrutinee, cases: _, } => { @@ -747,7 +747,7 @@ impl InnerInPlaceTransform for cfg::ControlInst { | cfg::ControlInstKind::ExitInvocation(cfg::ExitInvocationKind::SpvInst(_)) | cfg::ControlInstKind::Branch | cfg::ControlInstKind::SelectBranch( - SelectionKind::BoolCond | SelectionKind::SpvInst(_), + SelectionKind::BoolCond | SelectionKind::Switch { case_consts: _ }, ) => {} } for v in inputs { diff --git a/src/visit.rs b/src/visit.rs index 1b5d9718..19a7a48b 100644 --- a/src/visit.rs +++ 
b/src/visit.rs @@ -475,7 +475,7 @@ impl<'a> FuncAt<'a, ControlNode> { } } ControlNodeKind::Select { - kind: SelectionKind::BoolCond | SelectionKind::SpvInst(_), + kind: SelectionKind::BoolCond | SelectionKind::Switch { case_consts: _ }, scrutinee, cases, } => { @@ -556,7 +556,7 @@ impl InnerVisit for cfg::ControlInst { | cfg::ControlInstKind::ExitInvocation(cfg::ExitInvocationKind::SpvInst(_)) | cfg::ControlInstKind::Branch | cfg::ControlInstKind::SelectBranch( - SelectionKind::BoolCond | SelectionKind::SpvInst(_), + SelectionKind::BoolCond | SelectionKind::Switch { case_consts: _ }, ) => {} } for v in inputs { From f61e300fc4a7d38e789500681e226b700681c321 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Mon, 20 Nov 2023 07:13:50 +0200 Subject: [PATCH 05/22] WIP: OpSpecConstantOp --- src/spv/print.rs | 41 ++++++++++++++++++++++++++++++++++++++++- src/spv/read.rs | 26 ++++++++++++++++++++++++++ src/spv/spec.rs | 1 + src/spv/write.rs | 27 +++++++++++++++++++++++++++ 4 files changed, 94 insertions(+), 1 deletion(-) diff --git a/src/spv/print.rs b/src/spv/print.rs index 1fc99dd6..714b064d 100644 --- a/src/spv/print.rs +++ b/src/spv/print.rs @@ -77,6 +77,9 @@ impl TokensForOperand { // FIXME(eddyb) keep a `&'static spec::Spec` if that can even speed up anything. struct OperandPrinter, ID, IDS: Iterator> { + // FIXME(eddyb) use a field like this to interpret `Opcode`/`OperandKind`, too. + wk: &'static spv::spec::WellKnown, + /// Input immediate operands to print from (may be grouped e.g. into literals). imms: iter::Peekable, @@ -123,7 +126,41 @@ impl, ID, IDS: Iterator> OperandPrint let def = kind.def(); assert!(matches!(def, spec::OperandKindDef::Literal { .. 
})); - let literal_token = if kind == spec::Spec::get().well_known.LiteralString { + let literal_token = if kind == self.wk.LiteralSpecConstantOpInteger { + assert_eq!(words.len(), 1); + let (_, inner_name, inner_def) = match u16::try_from(first_word) + .ok() + .and_then(spec::Opcode::try_from_u16_with_name_and_def) + { + Some(opcode_name_and_def) => opcode_name_and_def, + None => { + self.out.tokens.push(Token::Error(format!( + "/* {first_word} not a valid `OpSpecConstantOp` opcode */" + ))); + return; + } + }; + + // FIXME(eddyb) deduplicate this with `enumerant_params`. + self.out.tokens.push(Token::EnumerandName(inner_name)); + + let mut first = true; + for (inner_mode, inner_name_and_kind) in inner_def.all_operands_with_names() { + if inner_mode == spec::OperandMode::Optional && self.is_exhausted() { + break; + } + + self.out.tokens.push(Token::Punctuation(if first { "(" } else { ", " })); + first = false; + + let (inner_name, inner_kind) = inner_name_and_kind.name_and_kind(); + self.operand(inner_name, inner_kind); + } + if !first { + self.out.tokens.push(Token::Punctuation(")")); + } + return; + } else if kind == self.wk.LiteralString { // FIXME(eddyb) deduplicate with `spv::extract_literal_string`. let bytes: SmallVec<[u8; 64]> = words .into_iter() @@ -260,6 +297,7 @@ impl, ID, IDS: Iterator> OperandPrint /// an enumerand with parameters (which consumes more immediates). 
pub fn operand_from_imms(imms: impl IntoIterator) -> TokensForOperand { let mut printer = OperandPrinter { + wk: &spec::Spec::get().well_known, imms: imms.into_iter().peekable(), ids: iter::empty().peekable(), out: TokensForOperand::default(), @@ -282,6 +320,7 @@ pub fn inst_operands( ids: impl IntoIterator, ) -> impl Iterator> { OperandPrinter { + wk: &spec::Spec::get().well_known, imms: imms.into_iter().peekable(), ids: ids.into_iter().peekable(), out: TokensForOperand::default(), diff --git a/src/spv/read.rs b/src/spv/read.rs index c532b804..cfa57a1d 100644 --- a/src/spv/read.rs +++ b/src/spv/read.rs @@ -27,6 +27,9 @@ impl KnownIdDef { // FIXME(eddyb) keep a `&'static spec::Spec` if that can even speed up anything. struct InstParser<'a> { + // FIXME(eddyb) use a field like this to interpret `Opcode`/`OperandKind`, too. + wk: &'static spv::spec::WellKnown, + /// IDs defined so far in the module. known_ids: &'a FxHashMap, @@ -59,6 +62,9 @@ enum InstParseError { /// The type of a `LiteralContextDependentNumber` was not a supported type /// (one of either `OpTypeInt` or `OpTypeFloat`). UnsupportedContextSensitiveLiteralType { type_opcode: spec::Opcode }, + + /// Unsupported `OpSpecConstantOp` (`LiteralSpecConstantOpInteger`) opcode. + UnsupportedSpecConstantOpOpcode(u32), } impl InstParseError { @@ -93,6 +99,9 @@ impl InstParseError { Self::UnsupportedContextSensitiveLiteralType { type_opcode } => { format!("{} is not a supported literal type", type_opcode.name()).into() } + Self::UnsupportedSpecConstantOpOpcode(opcode) => { + format!("{opcode} is not a supported opcode (for `OpSpecConstantOp`)").into() + } } } } @@ -194,6 +203,22 @@ impl InstParser<'_> { } } + // HACK(eddyb) this isn't cleanly uniform because it's an odd special case. + if kind == self.wk.LiteralSpecConstantOpInteger { + // FIXME(eddyb) this partially duplicates the main instruction parsing. 
+ let (_, _, inner_def) = u16::try_from(word) + .ok() + .and_then(spec::Opcode::try_from_u16_with_name_and_def) + .ok_or(Error::UnsupportedSpecConstantOpOpcode(word))?; + + for (inner_mode, inner_kind) in inner_def.all_operands() { + if inner_mode == spec::OperandMode::Optional && self.is_exhausted() { + break; + } + self.operand(inner_kind)?; + } + } + Ok(()) } @@ -317,6 +342,7 @@ impl Iterator for ModuleParser { } let parser = InstParser { + wk: &spec::Spec::get().well_known, known_ids: &self.known_ids, words: words[1..inst_len].iter().copied(), inst: spv::InstWithIds { diff --git a/src/spv/spec.rs b/src/spv/spec.rs index 7fe89260..61646c61 100644 --- a/src/spv/spec.rs +++ b/src/spv/spec.rs @@ -175,6 +175,7 @@ def_well_known! { LiteralExtInstInteger, LiteralString, LiteralContextDependentNumber, + LiteralSpecConstantOpInteger, ], // FIXME(eddyb) find a way to namespace these to avoid conflicts. addressing_model: u32 = [ diff --git a/src/spv/write.rs b/src/spv/write.rs index 0d0a9312..083dfea6 100644 --- a/src/spv/write.rs +++ b/src/spv/write.rs @@ -7,6 +7,9 @@ use std::{fs, io, iter, slice}; // FIXME(eddyb) keep a `&'static spec::Spec` if that can even speed up anything. struct OperandEmitter<'a> { + // FIXME(eddyb) use a field like this to interpret `Opcode`/`OperandKind`, too. + wk: &'static spv::spec::WellKnown, + /// Input immediate operands of an instruction. imms: iter::Copied>, @@ -32,6 +35,9 @@ enum OperandEmitError { /// Unsupported enumerand value. UnsupportedEnumerand(spec::OperandKind, u32), + + /// Unsupported `OpSpecConstantOp` (`LiteralSpecConstantOpInteger`) opcode. 
+ UnsupportedSpecConstantOpOpcode(u32), } impl OperandEmitError { @@ -60,6 +66,9 @@ impl OperandEmitError { _ => unreachable!(), } } + Self::UnsupportedSpecConstantOpOpcode(opcode) => { + format!("{opcode} is not a supported opcode (for `OpSpecConstantOp`)").into() + } } } } @@ -140,6 +149,23 @@ impl OperandEmitter<'_> { } } + // HACK(eddyb) this isn't cleanly uniform because it's an odd special case. + if kind == self.wk.LiteralSpecConstantOpInteger { + // FIXME(eddyb) this partially duplicates the main instruction emission. + let &word = self.out.last().unwrap(); + let (_, _, inner_def) = u16::try_from(word) + .ok() + .and_then(spec::Opcode::try_from_u16_with_name_and_def) + .ok_or(Error::UnsupportedSpecConstantOpOpcode(word))?; + + for (inner_mode, inner_kind) in inner_def.all_operands() { + if inner_mode == spec::OperandMode::Optional && self.is_exhausted() { + break; + } + self.operand(inner_kind)?; + } + } + Ok(()) } @@ -221,6 +247,7 @@ impl ModuleEmitter { ); OperandEmitter { + wk: &spec::Spec::get().well_known, imms: inst.imms.iter().copied(), ids: inst.ids.iter().copied(), out: &mut self.words, From 44faddcf64a3832a52f2be2151439265508c56cf Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Sun, 5 Nov 2023 01:07:21 +0200 Subject: [PATCH 06/22] Add `TypeKind::Vector`&`ConstKind::Scalar` for vector types&consts. 
--- src/lib.rs | 36 ++++++++++- src/print/mod.rs | 149 +++++++++++++++++++------------------------ src/qptr/layout.rs | 31 +++++---- src/spv/canonical.rs | 140 ++++++++++++++++++++++++++++------------ src/spv/lift.rs | 91 +++++++++++++++++--------- src/spv/lower.rs | 24 ++----- src/spv/spec.rs | 3 +- src/transform.rs | 2 + src/vector.rs | 123 +++++++++++++++++++++++++++++++++++ src/visit.rs | 10 ++- 10 files changed, 420 insertions(+), 189 deletions(-) create mode 100644 src/vector.rs diff --git a/src/lib.rs b/src/lib.rs index e767790f..d85e8cab 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -170,6 +170,7 @@ pub mod passes { pub mod qptr; pub mod scalar; pub mod spv; +pub mod vector; use smallvec::SmallVec; use std::borrow::Cow; @@ -471,6 +472,13 @@ pub enum TypeKind { #[from] Scalar(scalar::Type), + /// Vector (small array of [`scalar`]s) type, with some limitations on the + /// supported component counts (but all standard ones should be included). + /// + /// See also the [`vector`] module for more documentation and definitions. + #[from] + Vector(vector::Type), + /// "Quasi-pointer", an untyped pointer-like abstract scalar that can represent /// both memory locations (in any address space) and other kinds of locations /// (e.g. SPIR-V `OpVariable`s in non-memory "storage classes"). @@ -509,7 +517,7 @@ macro_rules! impl_intern_type_kind { })+ } } -impl_intern_type_kind!(TypeKind, scalar::Type); +impl_intern_type_kind!(TypeKind, scalar::Type, vector::Type); // HACK(eddyb) this is like `Either`, only used in `TypeKind::SpvInst`, // and only because SPIR-V type definitions can references both types and consts. @@ -527,6 +535,12 @@ impl Type { _ => None, } } + pub fn as_vector(self, cx: &Context) -> Option { + match cx[self].kind { + TypeKind::Vector(ty) => Some(ty), + _ => None, + } + } } /// Interned handle for a [`ConstDef`](crate::ConstDef) (a constant value). 
@@ -562,6 +576,18 @@ pub enum ConstKind { #[from] Scalar(scalar::Const), + /// Vector (small array of [`scalar`]s) constant, which must have + /// a type of [`TypeKind::Vector`] (of the same [`vector::Type`]). + /// + /// See also the [`vector`] module for more documentation and definitions. + // + // FIXME(eddyb) maybe document the 128-bit limitation inherited from `scalar::Const`? + // FIXME(eddyb) this technically makes the `vector::Type` redundant, could + // it get out of sync? (perhaps "forced canonicalization" could be used to + // enforce that interning simply doesn't allow such scenarios?). + #[from] + Vector(vector::Const), + PtrToGlobalVar(GlobalVar), // HACK(eddyb) this is a fallback case that should become increasingly rare @@ -592,7 +618,7 @@ macro_rules! impl_intern_const_kind { })+ } } -impl_intern_const_kind!(scalar::Const); +impl_intern_const_kind!(scalar::Const, vector::Const); // HACK(eddyb) on `Const` instead of `ConstDef` for ergonomics reasons. impl Const { @@ -602,6 +628,12 @@ impl Const { _ => None, } } + pub fn as_vector(self, cx: &Context) -> Option<&vector::Const> { + match &cx[self].kind { + ConstKind::Vector(ct) => Some(ct), + _ => None, + } + } } /// Declarations ([`GlobalVarDecl`], [`FuncDecl`]) can contain a full definition, diff --git a/src/print/mod.rs b/src/print/mod.rs index 59719995..33e9de4d 100644 --- a/src/print/mod.rs +++ b/src/print/mod.rs @@ -673,7 +673,6 @@ enum UseStyle { impl<'a> Printer<'a> { fn new(plan: &Plan<'a>) -> Self { let cx = plan.cx; - let wk = &spv::spec::Spec::get().well_known; // HACK(eddyb) move this elsewhere. enum SmallSet { @@ -813,21 +812,18 @@ impl<'a> Printer<'a> { CxInterned::Type(ty) => { let ty_def = &cx[ty]; - // FIXME(eddyb) remove the duplication between - // here and `TypeDef`'s `Print` impl. 
- let has_compact_print_or_is_leaf = match &ty_def.kind { - TypeKind::SpvInst { spv_inst, type_and_const_inputs } => { - spv_inst.opcode == wk.OpTypeVector - || type_and_const_inputs.is_empty() + let is_leaf = match &ty_def.kind { + TypeKind::SpvInst { type_and_const_inputs, .. } => { + type_and_const_inputs.is_empty() } TypeKind::Scalar(_) + | TypeKind::Vector(_) | TypeKind::QPtr | TypeKind::SpvStringLiteralForExtInst => true, }; - ty_def.attrs == AttrSet::default() - && has_compact_print_or_is_leaf + ty_def.attrs == AttrSet::default() && is_leaf } CxInterned::Const(ct) => { let ct_def = &cx[ct]; @@ -2360,70 +2356,43 @@ impl Print for TypeDef { let wk = &spv::spec::Spec::get().well_known; - // FIXME(eddyb) should this be done by lowering SPIR-V types to SPIR-T? let kw = |kw| printer.declarative_keyword_style().apply(kw).into(); - #[allow(irrefutable_let_patterns)] - let compact_def = if let &TypeKind::SpvInst { - spv_inst: spv::Inst { opcode, ref imms }, - ref type_and_const_inputs, - } = kind - { - if opcode == wk.OpTypeVector { - let (elem_ty, elem_count) = match (&imms[..], &type_and_const_inputs[..]) { - (&[spv::Imm::Short(_, elem_count)], &[TypeOrConst::Type(elem_ty)]) => { - (elem_ty, elem_count) - } - _ => unreachable!(), - }; - Some(pretty::Fragment::new([ - elem_ty.print(printer), - "×".into(), - printer.numeric_literal_style().apply(format!("{elem_count}")).into(), - ])) - } else { - None + // FIXME(eddyb) should this just be `fmt::Display` on `scalar::Type`? 
+ let print_scalar = |ty: scalar::Type| { + let width = ty.bit_width(); + match ty { + scalar::Type::Bool => "bool".into(), + scalar::Type::SInt(_) => format!("s{width}"), + scalar::Type::UInt(_) => format!("u{width}"), + scalar::Type::Float(_) => format!("f{width}"), } - } else { - None }; AttrsAndDef { attrs: attrs.print(printer), - def_without_name: if let Some(def) = compact_def { - def - } else { - match kind { - TypeKind::Scalar(ty) => { - let width = ty.bit_width(); - kw(match ty { - scalar::Type::Bool => "bool".into(), - scalar::Type::SInt(_) => format!("s{width}"), - scalar::Type::UInt(_) => format!("u{width}"), - scalar::Type::Float(_) => format!("f{width}"), - }) - } - - // FIXME(eddyb) should this be shortened to `qtr`? - TypeKind::QPtr => printer.declarative_keyword_style().apply("qptr").into(), - - TypeKind::SpvInst { spv_inst, type_and_const_inputs } => printer - .pretty_spv_inst( - printer.spv_op_style(), - spv_inst.opcode, - &spv_inst.imms, - type_and_const_inputs.iter().map(|&ty_or_ct| match ty_or_ct { - TypeOrConst::Type(ty) => ty.print(printer), - TypeOrConst::Const(ct) => ct.print(printer), - }), - ), - TypeKind::SpvStringLiteralForExtInst => pretty::Fragment::new([ - printer.error_style().apply("type_of").into(), - "(".into(), - printer.pretty_spv_opcode(printer.spv_op_style(), wk.OpString), - ")".into(), - ]), - } + def_without_name: match kind { + &TypeKind::Scalar(ty) => kw(print_scalar(ty)), + &TypeKind::Vector(ty) => kw(format!("{}×{}", print_scalar(ty.elem), ty.elem_count)), + + // FIXME(eddyb) should this be shortened to `qtr`? 
+ TypeKind::QPtr => printer.declarative_keyword_style().apply("qptr").into(), + + TypeKind::SpvInst { spv_inst, type_and_const_inputs } => printer.pretty_spv_inst( + printer.spv_op_style(), + spv_inst.opcode, + &spv_inst.imms, + type_and_const_inputs.iter().map(|&ty_or_ct| match ty_or_ct { + TypeOrConst::Type(ty) => ty.print(printer), + TypeOrConst::Const(ct) => ct.print(printer), + }), + ), + TypeKind::SpvStringLiteralForExtInst => pretty::Fragment::new([ + printer.error_style().apply("type_of").into(), + "(".into(), + printer.pretty_spv_opcode(printer.spv_op_style(), wk.OpString), + ")".into(), + ]), }, } } @@ -2438,14 +2407,11 @@ impl Print for ConstDef { let kw = |kw| printer.declarative_keyword_style().apply(kw).into(); - let def_without_name = match kind { - ConstKind::Undef => pretty::Fragment::new([ - printer.imperative_keyword_style().apply("undef").into(), - printer.pretty_type_ascription_suffix(*ty), - ]), - ConstKind::Scalar(scalar::Const::FALSE) => kw("false"), - ConstKind::Scalar(scalar::Const::TRUE) => kw("true"), - ConstKind::Scalar(ct) => { + // FIXME(eddyb) should this just be a method on `scalar::Const` instead? + let print_scalar = |ct: scalar::Const, include_type_suffix: bool| match ct { + scalar::Const::FALSE => kw("false"), + scalar::Const::TRUE => kw("true"), + _ => { let ty = ct.ty(); let width = ty.bit_width(); let (maybe_printed_value, ty_prefix) = match ty { @@ -2492,17 +2458,19 @@ impl Print for ConstDef { }; match maybe_printed_value { Some(printed_value) => { - let literal_ty_suffix = pretty::Styles { - // HACK(eddyb) the exact type detracts from the value. - color_opacity: Some(0.4), - subscript: true, - ..printer.declarative_keyword_style() + let printed_value = printer.numeric_literal_style().apply(printed_value); + if include_type_suffix { + let literal_ty_suffix = pretty::Styles { + // HACK(eddyb) the exact type detracts from the value. 
+ color_opacity: Some(0.4), + subscript: true, + ..printer.declarative_keyword_style() + } + .apply(format!("{ty_prefix}{width}")); + pretty::Fragment::new([printed_value, literal_ty_suffix]) + } else { + printed_value.into() } - .apply(format!("{ty_prefix}{width}")); - pretty::Fragment::new([ - printer.numeric_literal_style().apply(printed_value), - literal_ty_suffix, - ]) } // HACK(eddyb) fallback using the bitwise representation. None => pretty::Fragment::new([ @@ -2523,6 +2491,18 @@ impl Print for ConstDef { ]), } } + }; + + let def_without_name = match kind { + ConstKind::Undef => pretty::Fragment::new([ + printer.imperative_keyword_style().apply("undef").into(), + printer.pretty_type_ascription_suffix(*ty), + ]), + &ConstKind::Scalar(ct) => print_scalar(ct, true), + ConstKind::Vector(ct) => pretty::Fragment::new([ + ty.print(printer), + pretty::join_comma_sep("(", ct.elems().map(|elem| print_scalar(elem, false)), ")"), + ]), &ConstKind::PtrToGlobalVar(gv) => { pretty::Fragment::new(["&".into(), gv.print(printer)]) } @@ -3251,6 +3231,7 @@ impl Print for FuncAt<'_, DataInst> { if let Value::Const(ct) = v { match &printer.cx[ct].kind { ConstKind::Undef + | ConstKind::Vector(_) | ConstKind::PtrToGlobalVar(_) | ConstKind::SpvInst { .. } => {} diff --git a/src/qptr/layout.rs b/src/qptr/layout.rs index 49cd48ed..0617ddb7 100644 --- a/src/qptr/layout.rs +++ b/src/qptr/layout.rs @@ -335,6 +335,21 @@ impl<'a> LayoutCache<'a> { } TypeKind::Scalar(ty) => return Ok(scalar(ty.bit_width())), + TypeKind::Vector(ty) => { + let len = u32::from(ty.elem_count.get()); + return array( + cx.intern(ty.elem), + ArrayParams { + fixed_len: Some(len), + known_stride: None, + + // NOTE(eddyb) this is specifically Vulkan "base alignment". + min_legacy_align: 1, + legacy_align_multiplier: if len <= 2 { 2 } else { 4 }, + }, + ); + } + // FIXME(eddyb) treat `QPtr`s as scalars. 
TypeKind::QPtr => { return Err(LayoutError(Diag::bug( @@ -359,15 +374,7 @@ impl<'a> LayoutCache<'a> { // FIXME(eddyb) categorize `OpTypePointer` by storage class and split on // logical vs physical here. scalar_with_size_and_align(self.config.logical_ptr_size_align) - } else if [wk.OpTypeVector, wk.OpTypeMatrix].contains(&spv_inst.opcode) { - let len = short_imm_at(0); - let (min_legacy_align, legacy_align_multiplier) = if spv_inst.opcode == wk.OpTypeVector - { - // NOTE(eddyb) this is specifically Vulkan "base alignment". - (1, if len <= 2 { 2 } else { 4 }) - } else { - (self.config.min_aggregate_legacy_align, 1) - }; + } else if spv_inst.opcode == wk.OpTypeMatrix { // NOTE(eddyb) `RowMajor` is disallowed on `OpTypeStruct` members below. array( match type_and_const_inputs[..] { @@ -375,10 +382,10 @@ impl<'a> LayoutCache<'a> { _ => unreachable!(), }, ArrayParams { - fixed_len: Some(len), + fixed_len: Some(short_imm_at(0)), known_stride: None, - min_legacy_align, - legacy_align_multiplier, + min_legacy_align: self.config.min_aggregate_legacy_align, + legacy_align_multiplier: 1, }, )? } else if [wk.OpTypeArray, wk.OpTypeRuntimeArray].contains(&spv_inst.opcode) { diff --git a/src/spv/canonical.rs b/src/spv/canonical.rs index e99d21dd..d50384d4 100644 --- a/src/spv/canonical.rs +++ b/src/spv/canonical.rs @@ -8,8 +8,9 @@ // FIXME(eddyb) should interning attempts check/apply these canonicalizations? use crate::spv::{self, spec}; -use crate::{scalar, ConstKind, Context, DataInstKind, Type, TypeKind}; +use crate::{scalar, vector, Const, ConstKind, Context, DataInstKind, Type, TypeKind, TypeOrConst}; use lazy_static::lazy_static; +use smallvec::SmallVec; // FIXME(eddyb) these ones could maybe make use of build script generation. macro_rules! def_mappable_ops { @@ -65,6 +66,7 @@ def_mappable_ops! 
{ OpTypeBool, OpTypeInt, OpTypeFloat, + OpTypeVector, } const { OpUndef, @@ -249,55 +251,86 @@ impl spv::Inst { // FIXME(eddyb) automate bidirectional mappings more (although the need // for conditional, i.e. "partial", mappings, adds a lot of complexity). - pub(super) fn as_canonical_type(&self) -> Option { + pub(super) fn as_canonical_type( + &self, + cx: &Context, + type_and_const_inputs: &[TypeOrConst], + ) -> Option { let Self { opcode, imms } = self; let (&opcode, imms) = (opcode, &imms[..]); let mo = MappableOps::get(); let int_width = || scalar::IntWidth::try_from_bits(self.int_or_float_type_bit_width()?); - match imms { - [] if opcode == mo.OpTypeBool => Some(scalar::Type::Bool.into()), - &[_, spv::Imm::Short(_, 0)] if opcode == mo.OpTypeInt => { + match (imms, type_and_const_inputs) { + ([], []) if opcode == mo.OpTypeBool => Some(scalar::Type::Bool.into()), + (&[_, spv::Imm::Short(_, 0)], []) if opcode == mo.OpTypeInt => { Some(scalar::Type::UInt(int_width()?).into()) } - &[_, spv::Imm::Short(_, 1)] if opcode == mo.OpTypeInt => { + (&[_, spv::Imm::Short(_, 1)], []) if opcode == mo.OpTypeInt => { Some(scalar::Type::SInt(int_width()?).into()) } - [_] if opcode == mo.OpTypeFloat => Some( + ([_], []) if opcode == mo.OpTypeFloat => Some( scalar::Type::Float(scalar::FloatWidth::try_from_bits( self.int_or_float_type_bit_width()?, )?) 
.into(), ), + (&[spv::Imm::Short(_, elem_count)], &[TypeOrConst::Type(elem_type)]) + if opcode == mo.OpTypeVector => + { + Some( + vector::Type { + elem: elem_type.as_scalar(cx)?, + elem_count: u8::try_from(elem_count).ok()?.try_into().ok()?, + } + .into(), + ) + } _ => None, } } - pub(super) fn from_canonical_type(type_kind: &TypeKind) -> Option { + pub(super) fn from_canonical_type( + cx: &Context, + type_kind: &TypeKind, + ) -> Option<(Self, SmallVec<[TypeOrConst; 2]>)> { let wk = &spec::Spec::get().well_known; let mo = MappableOps::get(); match type_kind { - &TypeKind::Scalar(ty) => match ty { - scalar::Type::Bool => Some(mo.OpTypeBool.into()), - scalar::Type::SInt(w) | scalar::Type::UInt(w) => Some(spv::Inst { - opcode: mo.OpTypeInt, - imms: [ - spv::Imm::Short(wk.LiteralInteger, w.bits()), - spv::Imm::Short( - wk.LiteralInteger, - matches!(ty, scalar::Type::SInt(_)) as u32, - ), - ] - .into_iter() - .collect(), - }), - scalar::Type::Float(w) => Some(spv::Inst { - opcode: mo.OpTypeFloat, - imms: [spv::Imm::Short(wk.LiteralInteger, w.bits())].into_iter().collect(), - }), - }, + &TypeKind::Scalar(ty) => Some(( + match ty { + scalar::Type::Bool => mo.OpTypeBool.into(), + scalar::Type::SInt(w) | scalar::Type::UInt(w) => spv::Inst { + opcode: mo.OpTypeInt, + imms: [ + spv::Imm::Short(wk.LiteralInteger, w.bits()), + spv::Imm::Short( + wk.LiteralInteger, + matches!(ty, scalar::Type::SInt(_)) as u32, + ), + ] + .into_iter() + .collect(), + }, + scalar::Type::Float(w) => spv::Inst { + opcode: mo.OpTypeFloat, + imms: [spv::Imm::Short(wk.LiteralInteger, w.bits())].into_iter().collect(), + }, + }, + [].into_iter().collect(), + )), + + TypeKind::Vector(ty) => Some(( + spv::Inst { + opcode: mo.OpTypeVector, + imms: [spv::Imm::Short(wk.LiteralInteger, ty.elem_count.get().into())] + .into_iter() + .collect(), + }, + [TypeOrConst::Type(cx.intern(ty.elem))].into_iter().collect(), + )), TypeKind::QPtr | TypeKind::SpvInst { .. 
} | TypeKind::SpvStringLiteralForExtInst => { None @@ -313,33 +346,60 @@ impl spv::Inst { // FIXME(eddyb) automate bidirectional mappings more (although the need // for conditional, i.e. "partial", mappings, adds a lot of complexity). - pub(super) fn as_canonical_const(&self, cx: &Context, ty: Type) -> Option { + pub(super) fn as_canonical_const( + &self, + cx: &Context, + ty: Type, + const_inputs: &[Const], + ) -> Option { let Self { opcode, imms } = self; let (&opcode, imms) = (opcode, &imms[..]); + let wk = &spec::Spec::get().well_known; let mo = MappableOps::get(); - match imms { - [] if opcode == mo.OpUndef => Some(ConstKind::Undef), - [] if opcode == mo.OpConstantFalse => Some(scalar::Const::FALSE.into()), - [] if opcode == mo.OpConstantTrue => Some(scalar::Const::TRUE.into()), - _ if opcode == mo.OpConstant => { + match (imms, const_inputs) { + ([], []) if opcode == mo.OpUndef => Some(ConstKind::Undef), + ([], []) if opcode == mo.OpConstantFalse => Some(scalar::Const::FALSE.into()), + ([], []) if opcode == mo.OpConstantTrue => Some(scalar::Const::TRUE.into()), + (_, []) if opcode == mo.OpConstant => { Some(scalar::Const::try_decode_from_spv_imms(ty.as_scalar(cx)?, imms)?.into()) } + _ if opcode == wk.OpConstantComposite => { + let ty = ty.as_vector(cx)?; + let elems = (const_inputs.len() == usize::from(ty.elem_count.get()) + && const_inputs.iter().all(|ct| ct.as_scalar(cx).is_some())) + .then(|| const_inputs.iter().map(|ct| *ct.as_scalar(cx).unwrap()))?; + Some(vector::Const::from_elems(ty, elems).into()) + } _ => None, } } - pub(super) fn from_canonical_const(const_kind: &ConstKind) -> Option { + pub(super) fn from_canonical_const( + cx: &Context, + const_kind: &ConstKind, + ) -> Option<(Self, SmallVec<[Const; 4]>)> { + let wk = &spec::Spec::get().well_known; let mo = MappableOps::get(); match const_kind { - ConstKind::Undef => Some(mo.OpUndef.into()), - ConstKind::Scalar(scalar::Const::FALSE) => Some(mo.OpConstantFalse.into()), - 
ConstKind::Scalar(scalar::Const::TRUE) => Some(mo.OpConstantTrue.into()), - ConstKind::Scalar(ct) => { - Some(spv::Inst { opcode: mo.OpConstant, imms: ct.encode_as_spv_imms().collect() }) - } + ConstKind::Undef => Some((mo.OpUndef.into(), [].into_iter().collect())), + &ConstKind::Scalar(ct) => Some(( + match ct { + scalar::Const::FALSE => mo.OpConstantFalse.into(), + scalar::Const::TRUE => mo.OpConstantTrue.into(), + _ => { + spv::Inst { opcode: mo.OpConstant, imms: ct.encode_as_spv_imms().collect() } + } + }, + [].into_iter().collect(), + )), + + ConstKind::Vector(ct) => Some(( + wk.OpConstantComposite.into(), + ct.elems().map(|elem| cx.intern(elem)).collect(), + )), ConstKind::PtrToGlobalVar(_) | ConstKind::SpvInst { .. } diff --git a/src/spv/lift.rs b/src/spv/lift.rs index 100a9f5c..28578d0e 100644 --- a/src/spv/lift.rs +++ b/src/spv/lift.rs @@ -121,8 +121,22 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { return; } let ty_def = &self.cx[ty]; + + // HACK(eddyb) there isn't a great way to handle canonical types, but + // perhaps this result should be recorded in `self.globals`? + if let Some((_spv_inst, type_and_const_inputs)) = + spv::Inst::from_canonical_type(self.cx, &ty_def.kind) + { + for ty_or_ct in type_and_const_inputs { + match ty_or_ct { + TypeOrConst::Type(ty) => self.visit_type_use(ty), + TypeOrConst::Const(ct) => self.visit_const_use(ct), + } + } + } + match ty_def.kind { - TypeKind::Scalar(_) | TypeKind::SpvInst { .. } => {} + TypeKind::Scalar(_) | TypeKind::Vector(_) | TypeKind::SpvInst { .. } => {} // FIXME(eddyb) this should be a proper `Result`-based error instead, // and/or `spv::lift` should mutate the module for legalization. 
@@ -137,6 +151,7 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { ); } } + self.visit_type_def(ty_def); self.globals.insert(global); } @@ -146,9 +161,21 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { return; } let ct_def = &self.cx[ct]; + + // HACK(eddyb) there isn't a great way to handle canonical consts, but + // perhaps this result should be recorded in `self.globals`? + if let Some((_spv_inst, const_inputs)) = + spv::Inst::from_canonical_const(self.cx, &ct_def.kind) + { + for ct in const_inputs { + self.visit_const_use(ct); + } + } + match ct_def.kind { ConstKind::Undef | ConstKind::Scalar(_) + | ConstKind::Vector(_) | ConstKind::PtrToGlobalVar(_) | ConstKind::SpvInst { .. } => { self.visit_const_def(ct_def); @@ -1030,9 +1057,10 @@ impl LazyInst<'_, '_> { (gv_decl.attrs, import) } - ConstKind::Undef | ConstKind::Scalar(_) | ConstKind::SpvInst { .. } => { - (ct_def.attrs, None) - } + ConstKind::Undef + | ConstKind::Scalar(_) + | ConstKind::Vector(_) + | ConstKind::SpvInst { .. } => (ct_def.attrs, None), // Not inserted into `globals` while visiting. 
ConstKind::SpvStringLiteralForExtInst(_) => unreachable!(), @@ -1100,19 +1128,16 @@ impl LazyInst<'_, '_> { Self::Global(global) => match global { Global::Type(ty) => { let ty_def = &cx[ty]; - match spv::Inst::from_canonical_type(&ty_def.kind).ok_or(&ty_def.kind) { - Ok(spv_inst) => spv::InstWithIds { - without_ids: spv_inst, - result_type_id: None, - result_id, - ids: [].into_iter().collect(), - }, - - Err(TypeKind::Scalar(_)) => { + match spv::Inst::from_canonical_type(cx, &ty_def.kind) + .as_ref() + .ok_or(&ty_def.kind) + { + Err(TypeKind::Scalar(_) | TypeKind::Vector(_)) => { unreachable!("should've been handled as canonical") } - Err(TypeKind::SpvInst { spv_inst, type_and_const_inputs }) => { + Ok((spv_inst, type_and_const_inputs)) + | Err(TypeKind::SpvInst { spv_inst, type_and_const_inputs }) => { spv::InstWithIds { without_ids: spv_inst.clone(), result_type_id: None, @@ -1137,15 +1162,32 @@ impl LazyInst<'_, '_> { } Global::Const(ct) => { let ct_def = &cx[ct]; - match spv::Inst::from_canonical_const(&ct_def.kind).ok_or(&ct_def.kind) { - Ok(spv_inst) => spv::InstWithIds { + match spv::Inst::from_canonical_const(cx, &ct_def.kind).ok_or(&ct_def.kind) { + // FIXME(eddyb) this duplicates the `ConstKind::SpvInst` + // case, only due to an inability to pattern-match `Rc`. 
+ Ok((spv_inst, const_inputs)) => spv::InstWithIds { without_ids: spv_inst, result_type_id: Some(ids.globals[&Global::Type(ct_def.ty)]), result_id, - ids: [].into_iter().collect(), + ids: const_inputs + .iter() + .map(|&ct| ids.globals[&Global::Const(ct)]) + .collect(), }, + Err(ConstKind::SpvInst { spv_inst_and_const_inputs }) => { + let (spv_inst, const_inputs) = &**spv_inst_and_const_inputs; + spv::InstWithIds { + without_ids: spv_inst.clone(), + result_type_id: Some(ids.globals[&Global::Type(ct_def.ty)]), + result_id, + ids: const_inputs + .iter() + .map(|&ct| ids.globals[&Global::Const(ct)]) + .collect(), + } + } - Err(ConstKind::Undef | ConstKind::Scalar(_)) => { + Err(ConstKind::Undef | ConstKind::Scalar(_) | ConstKind::Vector(_)) => { unreachable!("should've been handled as canonical") } @@ -1182,19 +1224,6 @@ impl LazyInst<'_, '_> { } } - Err(ConstKind::SpvInst { spv_inst_and_const_inputs }) => { - let (spv_inst, const_inputs) = &**spv_inst_and_const_inputs; - spv::InstWithIds { - without_ids: spv_inst.clone(), - result_type_id: Some(ids.globals[&Global::Type(ct_def.ty)]), - result_id, - ids: const_inputs - .iter() - .map(|&ct| ids.globals[&Global::Const(ct)]) - .collect(), - } - } - // Not inserted into `globals` while visiting. 
Err(ConstKind::SpvStringLiteralForExtInst(_)) => unreachable!(), } diff --git a/src/spv/lower.rs b/src/spv/lower.rs index 2c9b7cab..b7a6077c 100644 --- a/src/spv/lower.rs +++ b/src/spv/lower.rs @@ -589,15 +589,9 @@ impl Module { let ty = cx.intern(TypeDef { attrs: mem::take(&mut attrs), - kind: match inst.as_canonical_type() { - Some(type_kind) => { - assert_eq!(type_and_const_inputs.len(), 0); - type_kind - } - None => { - TypeKind::SpvInst { spv_inst: inst.without_ids, type_and_const_inputs } - } - }, + kind: inst.as_canonical_type(&cx, &type_and_const_inputs).unwrap_or( + TypeKind::SpvInst { spv_inst: inst.without_ids, type_and_const_inputs }, + ), }); id_defs.insert(id, IdDef::Type(ty)); @@ -626,15 +620,11 @@ impl Module { let ct = cx.intern(ConstDef { attrs: mem::take(&mut attrs), ty, - kind: match inst.as_canonical_const(&cx, ty) { - Some(const_kind) => { - assert_eq!(const_inputs.len(), 0); - const_kind - } - None => ConstKind::SpvInst { + kind: inst.as_canonical_const(&cx, ty, &const_inputs).unwrap_or_else(|| { + ConstKind::SpvInst { spv_inst_and_const_inputs: Rc::new((inst.without_ids, const_inputs)), - }, - }, + } + }), }); id_defs.insert(id, IdDef::Const(ct)); diff --git a/src/spv/spec.rs b/src/spv/spec.rs index 61646c61..5e5b2398 100644 --- a/src/spv/spec.rs +++ b/src/spv/spec.rs @@ -117,7 +117,6 @@ def_well_known! { OpNoLine, OpTypeVoid, - OpTypeVector, OpTypeMatrix, OpTypeArray, OpTypeRuntimeArray, @@ -130,6 +129,8 @@ def_well_known! 
{ OpTypeSampledImage, OpTypeAccelerationStructureKHR, + OpConstantComposite, + OpVariable, OpFunction, diff --git a/src/transform.rs b/src/transform.rs index eef7932a..e7cf80c3 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -425,6 +425,7 @@ impl InnerTransform for TypeDef { attrs -> transformer.transform_attr_set_use(*attrs), kind -> match kind { TypeKind::Scalar(_) + | TypeKind::Vector(_) | TypeKind::QPtr | TypeKind::SpvStringLiteralForExtInst => Transformed::Unchanged, @@ -461,6 +462,7 @@ impl InnerTransform for ConstDef { kind -> match kind { ConstKind::Undef | ConstKind::Scalar(_) + | ConstKind::Vector(_) | ConstKind::SpvStringLiteralForExtInst(_) => Transformed::Unchanged, ConstKind::PtrToGlobalVar(gv) => transform!({ diff --git a/src/vector.rs b/src/vector.rs new file mode 100644 index 00000000..0b9c2de8 --- /dev/null +++ b/src/vector.rs @@ -0,0 +1,123 @@ +//! Vector types (small arrays of [`scalar`](crate::scalar)s) and associated functionality. +//! +//! **Note**: these are similar to SIMD types in other IRs, but SPIR-V often uses +//! its `OpTypeVector` to represent geometrical vectors, colors, etc. without any +//! expectation of SIMD execution (which most GPU execution models use implicitly, +//! i.e. one non-uniform scalar becomes a hardware SIMD vector, while a high-level +//! "vector" of N "lanes", becomes N separate hardware SIMD vectors). + +use crate::scalar; +use smallvec::SmallVec; +use std::num::NonZeroU8; +use std::rc::Rc; + +// FIXME(eddyb) this entire module shorthands "element" as "elem", is that good? + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub struct Type { + pub elem: scalar::Type, + // FIXME(eddyb) maybe wrap this in a type that abstracts away the encoding? + pub elem_count: NonZeroU8, +} + +// FIXME(eddyb) document the 128-bit limitations inherited from `scalar::Const`. +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct Const(ConstRepr); + +// HACK(eddyb) `#[repr(packed)]` not allowed on `enum`s themselves. 
+#[repr(packed)] +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +struct Packed(T); + +// FIXME(eddyb) maybe build an abstraction for "N-dimensional" bit arrays? +#[derive(Clone, PartialEq, Eq, Hash)] +#[repr(u8)] +enum ConstRepr { + // HACK(eddyb) `(Type, u128)` would waste almost half its size on padding, and + // packing will only impact accessing the bits, while allowing e.g. being + // wrapped in an outer `enum`, before reaching the same size as `(u128, u128)`. + Inline(Type, Packed), + + // HACK(eddyb) this does raise the alignment, but the size and alignment are + // kept at one pointer (so likely half of `u128`) - `Packed>` is sadly + // not an option because `#[derive(...)]` + `#[repr(packed)]` often requires + // `Copy` in order to be able to safely take references (to a copy of a field). + Boxed(Type, Rc>), +} + +impl Const { + pub const fn ty(&self) -> Type { + match self.0 { + ConstRepr::Inline(ty, _) | ConstRepr::Boxed(ty, _) => ty, + } + } + + pub fn from_elems(ty: Type, elems: impl IntoIterator) -> Const { + let elem_width = ty.elem.bit_width(); + assert!(elem_width <= 128); + + let expected_elem_count = u32::from(ty.elem_count.get()); + + let num_limbs = elem_width.checked_mul(expected_elem_count).unwrap().div_ceil(128); + assert_ne!(num_limbs, 0); + let mut limbs = SmallVec::<[u128; 1]>::from_elem(0, usize::try_from(num_limbs).unwrap()); + + let mut found_elem_count = 0; + for ct in elems { + let i: u32 = found_elem_count; + found_elem_count = found_elem_count.checked_add(1).unwrap(); + if i >= expected_elem_count { + continue; + } + + // FIXME(eddyb) get better names (perhaps from miri-like memory?). 
+ let first_bit_idx = i.checked_mul(elem_width).unwrap(); + let limb_idx = first_bit_idx / 128; + let intra_limb_first_bit_idx = first_bit_idx % 128; + assert!(intra_limb_first_bit_idx + elem_width <= 128); + + limbs[usize::try_from(limb_idx).unwrap()] |= ct.bits() << intra_limb_first_bit_idx; + } + assert_eq!(found_elem_count, expected_elem_count); + + match limbs.into_inner() { + Ok([limb]) => Const(ConstRepr::Inline(ty, Packed(limb))), + Err(limbs) => Const(ConstRepr::Boxed(ty, Rc::new(limbs.into_vec()))), + } + } + + pub fn get_elem(&self, i: usize) -> Option { + let ty = self.ty(); + if i >= usize::from(ty.elem_count.get()) { + return None; + } + let i = u32::try_from(i).unwrap(); + let elem_width = ty.elem.bit_width(); + assert!(elem_width <= 128); + + // FIXME(eddyb) get better names (perhaps from miri-like memory?). + let first_bit_idx = i.checked_mul(elem_width).unwrap(); + let limb_idx = first_bit_idx / 128; + let intra_limb_first_bit_idx = first_bit_idx % 128; + assert!(intra_limb_first_bit_idx + elem_width <= 128); + + let limb = match &self.0 { + ConstRepr::Inline(_, limb) => { + assert_eq!(limb_idx, 0); + limb.0 + } + ConstRepr::Boxed(_, limbs) => limbs[usize::try_from(limb_idx).unwrap()], + }; + + Some(scalar::Const::from_bits( + ty.elem, + (limb >> intra_limb_first_bit_idx) & (!0 >> (128 - elem_width)), + )) + } + + pub fn elems(&self) -> impl Iterator + '_ { + let ty = self.ty(); + // FIXME(eddyb) there should be a more efficient way to do this. 
+ (0..usize::from(ty.elem_count.get())).map(|i| self.get_elem(i).unwrap()) + } +} diff --git a/src/visit.rs b/src/visit.rs index 19a7a48b..e2c4e897 100644 --- a/src/visit.rs +++ b/src/visit.rs @@ -315,7 +315,10 @@ impl InnerVisit for TypeDef { visitor.visit_attr_set_use(*attrs); match kind { - TypeKind::Scalar(_) | TypeKind::QPtr | TypeKind::SpvStringLiteralForExtInst => {} + TypeKind::Scalar(_) + | TypeKind::Vector(_) + | TypeKind::QPtr + | TypeKind::SpvStringLiteralForExtInst => {} TypeKind::SpvInst { spv_inst: _, type_and_const_inputs } => { for &ty_or_ct in type_and_const_inputs { @@ -336,7 +339,10 @@ impl InnerVisit for ConstDef { visitor.visit_attr_set_use(*attrs); visitor.visit_type_use(*ty); match kind { - ConstKind::Undef | ConstKind::Scalar(_) | ConstKind::SpvStringLiteralForExtInst(_) => {} + ConstKind::Undef + | ConstKind::Scalar(_) + | ConstKind::Vector(_) + | ConstKind::SpvStringLiteralForExtInst(_) => {} &ConstKind::PtrToGlobalVar(gv) => visitor.visit_global_var_use(gv), ConstKind::SpvInst { spv_inst_and_const_inputs } => { From b349dea66660ca7280081ae80986cc9ccfda5558 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Sun, 5 Nov 2023 16:22:03 +0200 Subject: [PATCH 07/22] Add `DataInstKind::Vector` for pure vector ops. --- src/lib.rs | 6 +++ src/print/mod.rs | 65 +++++++++++++++++++++++++----- src/qptr/analyze.rs | 2 +- src/qptr/lift.rs | 2 +- src/qptr/lower.rs | 5 ++- src/spv/canonical.rs | 94 +++++++++++++++++++++++++++++++++++++++----- src/spv/lift.rs | 8 +++- src/spv/spec.rs | 6 +++ src/transform.rs | 1 + src/vector.rs | 57 +++++++++++++++++++++++++++ src/visit.rs | 1 + 11 files changed, 223 insertions(+), 24 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d85e8cab..88aeeca8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -987,6 +987,12 @@ pub enum DataInstKind { #[from] Scalar(scalar::Op), + /// Vector (small array of [`scalar`]s) pure operations. + /// + /// See also the [`vector`] module for more documentation and definitions. 
+ #[from] + Vector(vector::Op), + // FIXME(eddyb) try to split this into recursive and non-recursive calls, // to avoid needing special handling for recursion where it's impossible. FuncCall(Func), diff --git a/src/print/mod.rs b/src/print/mod.rs index 33e9de4d..6b68c8d0 100644 --- a/src/print/mod.rs +++ b/src/print/mod.rs @@ -24,8 +24,8 @@ use crate::print::multiversion::Versions; use crate::qptr::{self, QPtrAttr, QPtrMemUsage, QPtrMemUsageKind, QPtrOp, QPtrUsage}; use crate::visit::{InnerVisit, Visit, Visitor}; use crate::{ - cfg, scalar, spv, AddrSpace, Attr, AttrSet, AttrSetDef, Const, ConstDef, ConstKind, Context, - ControlNode, ControlNodeDef, ControlNodeKind, ControlNodeOutputDecl, ControlRegion, + cfg, scalar, spv, vector, AddrSpace, Attr, AttrSet, AttrSetDef, Const, ConstDef, ConstKind, + Context, ControlNode, ControlNodeDef, ControlNodeKind, ControlNodeOutputDecl, ControlRegion, ControlRegionDef, ControlRegionInputDecl, DataInst, DataInstDef, DataInstForm, DataInstFormDef, DataInstKind, DeclDef, Diag, DiagLevel, DiagMsgPart, EntityListIter, ExportKey, Exportee, Func, FuncDecl, FuncParam, FxIndexMap, FxIndexSet, GlobalVar, GlobalVarDecl, GlobalVarDefBody, @@ -2407,7 +2407,7 @@ impl Print for ConstDef { let kw = |kw| printer.declarative_keyword_style().apply(kw).into(); - // FIXME(eddyb) should this just a method on `scalar::Const` instead? + // FIXME(eddyb) should this be a method on `scalar::Const` instead? let print_scalar = |ct: scalar::Const, include_type_suffix: bool| match ct { scalar::Const::FALSE => kw("false"), scalar::Const::TRUE => kw("true"), @@ -3023,17 +3023,62 @@ impl Print for FuncAt<'_, DataInst> { let mut output_type_to_print = *output_type; + // FIXME(eddyb) should this be a method on `scalar::Op` instead? 
+ let print_scalar = |op: scalar::Op| { + let name = op.name(); + let (namespace_prefix, name) = name.split_at(name.find('.').unwrap() + 1); + pretty::Fragment::new([ + printer + .demote_style_for_namespace_prefix(printer.declarative_keyword_style()) + .apply(namespace_prefix), + printer.declarative_keyword_style().apply(name), + ]) + }; + let def_without_type = match kind { - &DataInstKind::Scalar(op) => { - let name = op.name(); + &DataInstKind::Scalar(op) => pretty::Fragment::new([ + print_scalar(op), + pretty::join_comma_sep("(", inputs.iter().map(|v| v.print(printer)), ")"), + ]), + + &DataInstKind::Vector(op) => { + let (name, extra_last_input) = match op { + vector::Op::Distribute(_) => ("vec.distribute", None), + vector::Op::Reduce(op) => (op.name(), None), + vector::Op::Whole(op) => ( + op.name(), + match op { + vector::WholeOp::Extract { elem_idx } + | vector::WholeOp::Insert { elem_idx } => Some( + printer.numeric_literal_style().apply(elem_idx.to_string()).into(), + ), + vector::WholeOp::New + | vector::WholeOp::DynExtract + | vector::WholeOp::DynInsert + | vector::WholeOp::Mul => None, + }, + ), + }; let (namespace_prefix, name) = name.split_at(name.find('.').unwrap() + 1); - pretty::Fragment::new([ + let mut pretty_name = pretty::Fragment::new([ printer .demote_style_for_namespace_prefix(printer.declarative_keyword_style()) - .apply(namespace_prefix) - .into(), - printer.declarative_keyword_style().apply(name).into(), - pretty::join_comma_sep("(", inputs.iter().map(|v| v.print(printer)), ")"), + .apply(namespace_prefix), + printer.declarative_keyword_style().apply(name), + ]); + if let vector::Op::Distribute(op) = op { + pretty_name = pretty::Fragment::new([ + pretty_name, + pretty::join_comma_sep("(", [print_scalar(op)], ")"), + ]); + } + pretty::Fragment::new([ + pretty_name, + pretty::join_comma_sep( + "(", + inputs.iter().map(|v| v.print(printer)).chain(extra_last_input), + ")", + ), ]) } diff --git a/src/qptr/analyze.rs b/src/qptr/analyze.rs 
index 51e66351..45183c1c 100644 --- a/src/qptr/analyze.rs +++ b/src/qptr/analyze.rs @@ -906,7 +906,7 @@ impl<'a> InferUsage<'a> { }); }; match &data_inst_form_def.kind { - DataInstKind::Scalar(_) => {} + DataInstKind::Scalar(_) | DataInstKind::Vector(_) => {} &DataInstKind::FuncCall(callee) => { match self.infer_usage_in_func(module, callee) { diff --git a/src/qptr/lift.rs b/src/qptr/lift.rs index f21875c9..f624d060 100644 --- a/src/qptr/lift.rs +++ b/src/qptr/lift.rs @@ -404,7 +404,7 @@ impl LiftToSpvPtrInstsInFunc<'_> { Ok((addr_space, self.lifter.layout_of(pointee_type)?)) }; let replacement_data_inst_def = match &data_inst_form_def.kind { - DataInstKind::Scalar(_) => return Ok(Transformed::Unchanged), + DataInstKind::Scalar(_) | DataInstKind::Vector(_) => return Ok(Transformed::Unchanged), &DataInstKind::FuncCall(_callee) => { for &v in &data_inst_def.inputs { diff --git a/src/qptr/lower.rs b/src/qptr/lower.rs index 87fa70a4..dec482e6 100644 --- a/src/qptr/lower.rs +++ b/src/qptr/lower.rs @@ -616,7 +616,10 @@ impl LowerFromSpvPtrInstsInFunc<'_> { match data_inst_form_def.kind { // Known semantics, no need to preserve SPIR-V pointer information. - DataInstKind::Scalar(_) | DataInstKind::FuncCall(_) | DataInstKind::QPtr(_) => return, + DataInstKind::Scalar(_) + | DataInstKind::Vector(_) + | DataInstKind::FuncCall(_) + | DataInstKind::QPtr(_) => return, DataInstKind::SpvInst(_) | DataInstKind::SpvExtInst { .. } => {} } diff --git a/src/spv/canonical.rs b/src/spv/canonical.rs index d50384d4..ea679873 100644 --- a/src/spv/canonical.rs +++ b/src/spv/canonical.rs @@ -17,12 +17,14 @@ macro_rules! def_mappable_ops { ( type { $($ty_op:ident),+ $(,)? } const { $($ct_op:ident),+ $(,)? } + data_inst { $($di_op:ident),+ $(,)? } $($enum_path:path { $($variant_op:ident <=> $variant:ident$(($($variant_args:tt)*))?),+ $(,)? 
})* ) => { #[allow(non_snake_case)] struct MappableOps { $($ty_op: spec::Opcode,)+ $($ct_op: spec::Opcode,)+ + $($di_op: spec::Opcode,)+ $($($variant_op: spec::Opcode,)+)* } impl MappableOps { @@ -35,6 +37,7 @@ macro_rules! def_mappable_ops { MappableOps { $($ty_op: spv_spec.instructions.lookup(stringify!($ty_op)).unwrap(),)+ $($ct_op: spv_spec.instructions.lookup(stringify!($ct_op)).unwrap(),)+ + $($di_op: spv_spec.instructions.lookup(stringify!($di_op)).unwrap(),)+ $($($variant_op: spv_spec.instructions.lookup(stringify!($variant_op)).unwrap(),)+)* } }; @@ -74,6 +77,11 @@ def_mappable_ops! { OpConstantTrue, OpConstant, } + data_inst { + OpVectorExtractDynamic, + OpVectorInsertDynamic, + OpVectorTimesScalar, + } scalar::BoolUnOp { OpLogicalNot <=> Not, } @@ -164,6 +172,11 @@ def_mappable_ops! { OpFUnordLessThanEqual <=> CmpOrUnord(scalar::FloatCmp::Le), OpFUnordGreaterThanEqual <=> CmpOrUnord(scalar::FloatCmp::Ge), } + vector::ReduceOp { + OpDot <=> Dot, + OpAny <=> Any, + OpAll <=> All, + } } impl scalar::Const { @@ -424,16 +437,46 @@ impl spv::Inst { if let Some(op) = scalar_op { assert_eq!(imms.len(), 0); - // FIXME(eddyb) support vector versions of these ops as well. 
- if output_types.len() == op.output_count() - && output_types.iter().all(|ty| ty.as_scalar(cx).is_some()) - { - Some(op.into()) + let (_scalar_type, vec_elem_count) = (output_types.len() == op.output_count()) + .then(|| { + output_types.iter().map(|&ty| match cx[ty].kind { + TypeKind::Scalar(ty) => Some((ty, None)), + TypeKind::Vector(ty) => Some((ty.elem, Some(ty.elem_count))), + _ => None, + }) + }) + .and_then(|mut outputs| { + let first = outputs.next().unwrap()?; + outputs.all(|x| x == Some(first)).then_some(first) + })?; + + Some(if vec_elem_count.is_some() { + vector::Op::Distribute(op).into() } else { - None - } + op.into() + }) + } else if let Some(op) = vector::ReduceOp::try_from_opcode(opcode).map(vector::Op::from) { + assert_eq!(imms.len(), 0); + Some(op.into()) } else { - None + let wk = &spec::Spec::get().well_known; + let mo = MappableOps::get(); + + // FIXME(eddyb) automate this by supporting immediates in the macro. + let v_whole = |op| Some(vector::Op::Whole(op).into()); + match imms { + [] if opcode == wk.OpCompositeConstruct => v_whole(vector::WholeOp::New), + &[spv::Imm::Short(_, elem_idx)] if opcode == wk.OpCompositeExtract => { + v_whole(vector::WholeOp::Extract { elem_idx: elem_idx.try_into().ok()? }) + } + &[spv::Imm::Short(_, elem_idx)] if opcode == wk.OpCompositeInsert => { + v_whole(vector::WholeOp::Insert { elem_idx: elem_idx.try_into().ok()? 
}) + } + [] if opcode == mo.OpVectorExtractDynamic => v_whole(vector::WholeOp::DynExtract), + [] if opcode == mo.OpVectorInsertDynamic => v_whole(vector::WholeOp::DynInsert), + [] if opcode == mo.OpVectorTimesScalar => v_whole(vector::WholeOp::Mul), + _ => None, + } } } @@ -447,7 +490,40 @@ impl spv::Inst { scalar::Op::FloatUnary(op) => op.to_opcode().into(), scalar::Op::FloatBinary(op) => op.to_opcode().into(), }), - _ => None, + &DataInstKind::Vector(op) => Some(match op { + vector::Op::Distribute(op) => { + Self::from_canonical_data_inst_kind(&DataInstKind::Scalar(op)).unwrap() + } + vector::Op::Reduce(op) => op.to_opcode().into(), + vector::Op::Whole(op) => { + let wk = &spec::Spec::get().well_known; + let mo = MappableOps::get(); + + // FIXME(eddyb) automate this by supporting immediates in the macro. + match op { + vector::WholeOp::New => wk.OpCompositeConstruct.into(), + vector::WholeOp::Extract { elem_idx } => spv::Inst { + opcode: wk.OpCompositeExtract, + imms: [spv::Imm::Short(wk.LiteralInteger, elem_idx.into())] + .into_iter() + .collect(), + }, + vector::WholeOp::Insert { elem_idx } => spv::Inst { + opcode: wk.OpCompositeInsert, + imms: [spv::Imm::Short(wk.LiteralInteger, elem_idx.into())] + .into_iter() + .collect(), + }, + vector::WholeOp::DynExtract => mo.OpVectorExtractDynamic.into(), + vector::WholeOp::DynInsert => mo.OpVectorInsertDynamic.into(), + vector::WholeOp::Mul => mo.OpVectorTimesScalar.into(), + } + } + }), + DataInstKind::FuncCall(_) + | DataInstKind::QPtr(_) + | DataInstKind::SpvInst(..) + | DataInstKind::SpvExtInst { .. 
} => None, } } } diff --git a/src/spv/lift.rs b/src/spv/lift.rs index 28578d0e..88170778 100644 --- a/src/spv/lift.rs +++ b/src/spv/lift.rs @@ -254,7 +254,11 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { unreachable!("`DataInstKind::QPtr` should be legalized away before lifting"); } - DataInstKind::Scalar(_) | DataInstKind::FuncCall(_) | DataInstKind::SpvInst(_) => {} + DataInstKind::Scalar(_) + | DataInstKind::Vector(_) + | DataInstKind::FuncCall(_) + | DataInstKind::SpvInst(_) => {} + DataInstKind::SpvExtInst { ext_set, .. } => { self.ext_inst_imports.insert(&self.cx[ext_set]); } @@ -1286,7 +1290,7 @@ impl LazyInst<'_, '_> { match spv::Inst::from_canonical_data_inst_kind(kind).ok_or(kind) { Ok(spv_inst) => (spv_inst, None), - Err(DataInstKind::Scalar(_)) => { + Err(DataInstKind::Scalar(_) | DataInstKind::Vector(_)) => { unreachable!("should've been handled as canonical") } diff --git a/src/spv/spec.rs b/src/spv/spec.rs index 5e5b2398..81ddb800 100644 --- a/src/spv/spec.rs +++ b/src/spv/spec.rs @@ -129,6 +129,7 @@ def_well_known! { OpTypeSampledImage, OpTypeAccelerationStructureKHR, + // FIXME(eddyb) hide these from code, lowering should handle most cases. OpConstantComposite, OpVariable, @@ -159,6 +160,11 @@ def_well_known! { OpPtrAccessChain, OpInBoundsPtrAccessChain, OpBitcast, + + // FIXME(eddyb) hide these from code, lowering should handle most cases. + OpCompositeInsert, + OpCompositeExtract, + OpCompositeConstruct, ], operand_kind: OperandKind = [ Capability, diff --git a/src/transform.rs b/src/transform.rs index e7cf80c3..053ecb60 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -724,6 +724,7 @@ impl InnerTransform for DataInstFormDef { | QPtrOp::Store => Transformed::Unchanged, }, DataInstKind::Scalar(_) + | DataInstKind::Vector(_) | DataInstKind::SpvInst(_) | DataInstKind::SpvExtInst { .. 
} => Transformed::Unchanged, }, diff --git a/src/vector.rs b/src/vector.rs index 0b9c2de8..0b1d42f5 100644 --- a/src/vector.rs +++ b/src/vector.rs @@ -121,3 +121,60 @@ impl Const { (0..usize::from(ty.elem_count.get())).map(|i| self.get_elem(i).unwrap()) } } + +/// Pure operations with vector inputs and/or outputs. +#[derive(Copy, Clone, PartialEq, Eq, Hash, derive_more::From)] +pub enum Op { + Distribute(scalar::Op), + Reduce(ReduceOp), + + // FIXME(eddyb) find a better name for this category of ops. + Whole(WholeOp), +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum ReduceOp { + // FIXME(eddyb) also support all the new integer dot product instructions. + Dot, + // FIXME(eddyb) model these using their respective `BoolBinOp`s? + Any, + All, +} + +impl ReduceOp { + pub fn name(self) -> &'static str { + match self { + ReduceOp::Dot => "vec.dot", + ReduceOp::Any => "vec.any", + ReduceOp::All => "vec.all", + } + } +} + +// FIXME(eddyb) find a better name for this category of ops. +// FIXME(eddyb) also support `OpVectorShuffle`. +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum WholeOp { + // FIXME(eddyb) better name for this (pack? make? "construct" is too long). + New, + Extract { elem_idx: u8 }, + Insert { elem_idx: u8 }, + DynExtract, + DynInsert, + + // FIXME(eddyb) may need a better name to indicate "scalar product". + Mul, +} + +impl WholeOp { + pub fn name(self) -> &'static str { + match self { + WholeOp::New => "vec.new", + WholeOp::Extract { .. } => "vec.extract", + WholeOp::Insert { .. } => "vec.insert", + WholeOp::DynExtract => "vec.dyn_extract", + WholeOp::DynInsert => "vec.dyn_insert", + WholeOp::Mul => "vec.mul", + } + } +} diff --git a/src/visit.rs b/src/visit.rs index e2c4e897..5a54a74c 100644 --- a/src/visit.rs +++ b/src/visit.rs @@ -542,6 +542,7 @@ impl InnerVisit for DataInstFormDef { | QPtrOp::Store => {} }, DataInstKind::Scalar(_) + | DataInstKind::Vector(_) | DataInstKind::SpvInst(_) | DataInstKind::SpvExtInst { .. 
} => {} } From 74fde563424b7f299724f49eb3c6e337710634eb Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Sun, 1 Oct 2023 17:02:02 +0300 Subject: [PATCH 08/22] qptr/analyze: fix some latent issues in merging, caused by ZSTs. --- src/qptr/analyze.rs | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/src/qptr/analyze.rs b/src/qptr/analyze.rs index 45183c1c..60cd5627 100644 --- a/src/qptr/analyze.rs +++ b/src/qptr/analyze.rs @@ -170,7 +170,18 @@ impl UsageMerger<'_> { // Decompose the "smaller" and/or "less strict" side (`b`) first. match b.kind { // `Unused`s are always ignored. - QPtrMemUsageKind::Unused => return MergeResult::ok(a), + QPtrMemUsageKind::Unused + if { + // HACK(eddyb) see similar comment below, but also the comment + // above is invalidated by this condition - the issue is that + // only an unused offset of `0` is a true noop, otherwise + // there is a dead `qptr.offset` instruction which still + // needs a field to reference. + b_offset_in_a == 0 + } => + { + return MergeResult::ok(a); + } QPtrMemUsageKind::OffsetBase(b_entries) if { @@ -397,12 +408,26 @@ impl UsageMerger<'_> { .range(( Bound::Unbounded, b.max_size.map_or(Bound::Unbounded, |b_max_size| { - Bound::Excluded(b_offset_in_a.checked_add(b_max_size).unwrap()) + // HACK(eddyb) the unconditional `insert` below, at + // `b_offset_in_a`, can overwrite an existing entry + // if the ZST case isn't correctly handled. + if b_max_size == 0 { + Bound::Included(b_offset_in_a) + } else { + Bound::Excluded(b_offset_in_a.checked_add(b_max_size).unwrap()) + } }), )) .rev() - .take_while(|(a_sub_offset, a_sub_usage)| { + .take_while(|&(&a_sub_offset, a_sub_usage)| { a_sub_usage.max_size.map_or(true, |a_sub_max_size| { + // HACK(eddyb) the unconditional `insert` below, at + // `b_offset_in_a`, can overwrite an existing entry + // if the ZST case isn't correctly handled. 
+ if b.max_size == Some(0) && a_sub_offset == b_offset_in_a { + return true; + } + a_sub_offset.checked_add(a_sub_max_size).unwrap() > b_offset_in_a }) }); From d4b0eb0ed32e45968273c5108b6d85ef50c4828a Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Wed, 8 Nov 2023 00:50:10 +0200 Subject: [PATCH 09/22] qptr/analyze: don't panic on oddball `Const` pointers. --- src/qptr/analyze.rs | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/qptr/analyze.rs b/src/qptr/analyze.rs index 60cd5627..94b2c253 100644 --- a/src/qptr/analyze.rs +++ b/src/qptr/analyze.rs @@ -901,8 +901,21 @@ impl<'a> InferUsage<'a> { ConstKind::PtrToGlobalVar(gv) => { this.global_var_usages.entry(gv).or_default() } - // FIXME(eddyb) may be relevant? - _ => unreachable!(), + // FIXME(eddyb) attach on the `Const` by replacing + // it with a copy that also has an extra attribute, + // or actually support by adding the usage attribute + // in the same manner (if it makes sense to do so). + _ => { + usage_or_err_attrs_to_attach.push(( + Value::DataInstOutput(data_inst), + Err(AnalysisError(Diag::bug([ + "unsupported pointer constant `".into(), + ct.into(), + "`".into(), + ]))), + )); + return; + } }, Value::ControlRegionInput { region, input_idx } if region == func_def_body.body => From 44cf1cdae1d48297563ff7f1e1955a379d49d8ba Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Sun, 1 Oct 2023 17:02:35 +0300 Subject: [PATCH 10/22] qptr: add an immediate `offset` to `Load`/`Store` ops. 
--- src/print/mod.rs | 31 ++++- src/qptr/analyze.rs | 43 +++++- src/qptr/lift.rs | 309 ++++++++++++++++++++++---------------------- src/qptr/lower.rs | 44 ++++++- src/qptr/mod.rs | 15 ++- src/transform.rs | 4 +- src/visit.rs | 4 +- 7 files changed, 268 insertions(+), 182 deletions(-) diff --git a/src/print/mod.rs b/src/print/mod.rs index 6b68c8d0..7618cbe2 100644 --- a/src/print/mod.rs +++ b/src/print/mod.rs @@ -3095,6 +3095,7 @@ impl Print for FuncAt<'_, DataInst> { QPtrOp::FuncLocalVar(_) => (None, &inputs[..]), _ => (Some(inputs[0]), &inputs[1..]), }; + let mut qptr_input = qptr_input.map(|v| v.print(printer)); let (name, extra_inputs): (_, SmallVec<[_; 1]>) = match op { QPtrOp::FuncLocalVar(mem_layout) => { assert!(extra_inputs.len() <= 1); @@ -3189,12 +3190,32 @@ impl Print for FuncAt<'_, DataInst> { ) } - QPtrOp::Load => { + &QPtrOp::Load { offset } => { assert_eq!(extra_inputs.len(), 0); + if offset != 0 { + qptr_input = Some(pretty::Fragment::new([ + qptr_input.take().unwrap(), + if offset < 0 { " - " } else { " + " }.into(), + printer + .numeric_literal_style() + .apply(offset.abs().to_string()) + .into(), + ])); + } ("load", [].into_iter().collect()) } - QPtrOp::Store => { + &QPtrOp::Store { offset } => { assert_eq!(extra_inputs.len(), 1); + if offset != 0 { + qptr_input = Some(pretty::Fragment::new([ + qptr_input.take().unwrap(), + if offset < 0 { " - " } else { " + " }.into(), + printer + .numeric_literal_style() + .apply(offset.abs().to_string()) + .into(), + ])); + } ("store", [extra_inputs[0].print(printer)].into_iter().collect()) } }; @@ -3205,11 +3226,7 @@ impl Print for FuncAt<'_, DataInst> { .apply("qptr.") .into(), printer.declarative_keyword_style().apply(name).into(), - pretty::join_comma_sep( - "(", - qptr_input.map(|v| v.print(printer)).into_iter().chain(extra_inputs), - ")", - ), + pretty::join_comma_sep("(", qptr_input.into_iter().chain(extra_inputs), ")"), ]) } diff --git a/src/qptr/analyze.rs b/src/qptr/analyze.rs index 
94b2c253..47160ad4 100644 --- a/src/qptr/analyze.rs +++ b/src/qptr/analyze.rs @@ -1147,10 +1147,14 @@ impl<'a> InferUsage<'a> { }), ); } - DataInstKind::QPtr(op @ (QPtrOp::Load | QPtrOp::Store)) => { + DataInstKind::QPtr( + op @ (QPtrOp::Load { offset } | QPtrOp::Store { offset }), + ) => { let (op_name, access_type) = match op { - QPtrOp::Load => ("Load", data_inst_form_def.output_type.unwrap()), - QPtrOp::Store => { + QPtrOp::Load { .. } => { + ("Load", data_inst_form_def.output_type.unwrap()) + } + QPtrOp::Store { .. } => { ("Store", func_at_inst.at(data_inst_def.inputs[1]).type_of(&cx)) } _ => unreachable!(), @@ -1162,7 +1166,7 @@ impl<'a> InferUsage<'a> { .layout_of(access_type) .map_err(|LayoutError(e)| AnalysisError(e)) .and_then(|layout| match layout { - TypeLayout::Handle(shapes::Handle::Opaque(ty)) => { + TypeLayout::Handle(shapes::Handle::Opaque(ty)) if *offset == 0 => { Ok(QPtrUsage::Handles(shapes::Handle::Opaque(ty))) } TypeLayout::Handle(shapes::Handle::Buffer(..)) => { @@ -1171,6 +1175,11 @@ impl<'a> InferUsage<'a> { ) .into()]))) } + TypeLayout::Handle(_) => { + Err(AnalysisError(Diag::bug([format!( + "{op_name} {{ offset: {offset} }}: cannot offset Handles" + ).into()]))) + } TypeLayout::HandleArray(..) => { Err(AnalysisError(Diag::bug([format!( "{op_name}: cannot access whole HandleArray" @@ -1186,9 +1195,33 @@ impl<'a> InferUsage<'a> { .into()]))) } TypeLayout::Concrete(concrete) => { - Ok(QPtrUsage::Memory(QPtrMemUsage { + let usage = QPtrMemUsage { max_size: Some(concrete.mem_layout.fixed_base.size), kind: QPtrMemUsageKind::DirectAccess(access_type), + }; + + // FIXME(eddyb) deduplicate this with + // `QPtrOp::Offset` above. 
+ let offset = u32::try_from(*offset).ok().ok_or_else(|| { + AnalysisError(Diag::bug([format!("{op_name} {{ offset: {offset} }}: negative offset").into()])) + })?; + + if offset == 0 { + return Ok(QPtrUsage::Memory(usage)); + } + + Ok(QPtrUsage::Memory(QPtrMemUsage { + max_size: usage + .max_size + .map(|max_size| offset.checked_add(max_size).ok_or_else(|| { + AnalysisError(Diag::bug([format!("{op_name} {{ offset: {offset} }}: size overflow ({offset}+{max_size})").into()])) + })).transpose()?, + // FIXME(eddyb) allocating `Rc>` + // to represent the one-element case, seems + // quite wasteful when it's likely consumed. + kind: QPtrMemUsageKind::OffsetBase(Rc::new( + [(offset, usage)].into(), + )), })) } }), diff --git a/src/qptr/lift.rs b/src/qptr/lift.rs index f624d060..58dd4f52 100644 --- a/src/qptr/lift.rs +++ b/src/qptr/lift.rs @@ -547,83 +547,21 @@ impl LiftToSpvPtrInstsInFunc<'_> { &DataInstKind::QPtr(QPtrOp::Offset(offset)) => { let base_ptr = data_inst_def.inputs[0]; let (addr_space, layout) = type_of_val_as_spv_ptr_with_layout(base_ptr)?; - let mut layout = match layout { - TypeLayout::Handle(_) | TypeLayout::HandleArray(..) => { - return Err(LiftError(Diag::bug(["cannot offset Handles".into()]))); - } - TypeLayout::Concrete(mem_layout) => mem_layout, - }; - let mut offset = u32::try_from(offset) - .ok() - .ok_or_else(|| LiftError(Diag::bug(["negative offset".into()])))?; - - let mut access_chain_inputs: SmallVec<_> = [base_ptr].into_iter().collect(); - // FIXME(eddyb) deduplicate with access chain loop for Load/Store. - while offset > 0 { - let idx = { - // HACK(eddyb) supporting ZSTs would be a pain because - // they can "fit" in weird ways, e.g. given 3 offsets - // A, B, C (before/between/after a pair of fields), - // `B..B` is included in both `A..B` and `B..C`. 
- let allow_zst = false; - let offset_range = if allow_zst { - offset..offset - } else { - offset..offset.saturating_add(1) - }; - let mut component_indices = - layout.components.find_components_containing(offset_range); - match (component_indices.next(), component_indices.next()) { - (None, _) => { - // FIXME(eddyb) this could include the chosen indices, - // and maybe the current type and/or layout. - return Err(LiftError(Diag::bug([format!( - "offset {offset} not found in type layout, after {} access chain indices", - access_chain_inputs.len() - 1 - ).into()]))); - } - (Some(idx), Some(_)) => { - // FIXME(eddyb) !!! this can also be illegal overlap - if allow_zst { - return Err(LiftError(Diag::bug([ - "ambiguity due to ZSTs in type layout".into(), - ]))); - } - // HACK(eddyb) letting illegal overlap through - idx - } - (Some(idx), None) => idx, - } - }; - let idx_as_i32 = i32::try_from(idx).ok().ok_or_else(|| { - LiftError(Diag::bug([ - format!("{idx} not representable as a positive s32").into() - ])) - })?; - access_chain_inputs - .push(Value::Const(cx.intern(scalar::Const::from_u32(idx_as_i32 as u32)))); - - match &layout.components { - Components::Scalar => unreachable!(), - Components::Elements { stride, elem, .. } => { - offset %= stride.get(); - layout = elem.clone(); - } - Components::Fields { offsets, layouts } => { - offset -= offsets[idx]; - layout = layouts[idx].clone(); - } - } - } - - if access_chain_inputs.len() == 1 { + self.maybe_adjust_pointer_for_offset_or_access( + base_ptr, + addr_space, + layout.clone(), + offset, + None, + )? 
+ .unwrap_or_else(|| { self.deferred_ptr_noops.insert( data_inst, DeferredPtrNoop { output_pointer: base_ptr, output_pointer_addr_space: addr_space, - output_pointee_layout: TypeLayout::Concrete(layout), + output_pointee_layout: layout, parent_block, }, ); @@ -637,18 +575,7 @@ impl LiftToSpvPtrInstsInFunc<'_> { }), ..data_inst_def.clone() } - } else { - DataInstDef { - attrs: data_inst_def.attrs, - form: cx.intern(DataInstFormDef { - kind: DataInstKind::SpvInst(wk.OpAccessChain.into()), - output_type: Some( - self.lifter.spv_ptr_type(addr_space, layout.original_type), - ), - }), - inputs: access_chain_inputs, - } - } + }) } DataInstKind::QPtr(QPtrOp::DynOffset { stride, index_bounds }) => { let base_ptr = data_inst_def.inputs[0]; @@ -677,7 +604,7 @@ impl LiftToSpvPtrInstsInFunc<'_> { } } - // FIXME(eddyb) deduplicate with `maybe_adjust_pointer_for_access`. + // FIXME(eddyb) deduplicate with `maybe_adjust_pointer_for_offset_or_access`. let idx = { // FIXME(eddyb) there might be a better way to // estimate a relevant offset range for the array, @@ -688,7 +615,7 @@ impl LiftToSpvPtrInstsInFunc<'_> { .and_then(|index_bounds| u32::try_from(index_bounds.end).ok()) .unwrap_or(0); let offset_range = - 0..min_expected_len.checked_add(stride.get()).unwrap_or(0); + 0..min_expected_len.checked_mul(stride.get()).unwrap_or(0); let mut component_indices = layout.components.find_components_containing(offset_range); match (component_indices.next(), component_indices.next()) { @@ -734,10 +661,10 @@ impl LiftToSpvPtrInstsInFunc<'_> { inputs: access_chain_inputs, } } - DataInstKind::QPtr(op @ (QPtrOp::Load | QPtrOp::Store)) => { + DataInstKind::QPtr(op @ (QPtrOp::Load { offset } | QPtrOp::Store { offset })) => { let (spv_opcode, access_type) = match op { - QPtrOp::Load => (wk.OpLoad, data_inst_form_def.output_type.unwrap()), - QPtrOp::Store => (wk.OpStore, type_of_val(data_inst_def.inputs[1])), + QPtrOp::Load { .. 
} => (wk.OpLoad, data_inst_form_def.output_type.unwrap()), + QPtrOp::Store { .. } => (wk.OpStore, type_of_val(data_inst_def.inputs[1])), _ => unreachable!(), }; @@ -746,11 +673,12 @@ impl LiftToSpvPtrInstsInFunc<'_> { let input_idx = 0; let ptr = data_inst_def.inputs[input_idx]; let (addr_space, pointee_layout) = type_of_val_as_spv_ptr_with_layout(ptr)?; - self.maybe_adjust_pointer_for_access( + self.maybe_adjust_pointer_for_offset_or_access( ptr, addr_space, pointee_layout, - access_type, + *offset, + Some(access_type), )? .map(|access_chain_data_inst_def| (input_idx, access_chain_data_inst_def)) .into_iter() @@ -814,11 +742,12 @@ impl LiftToSpvPtrInstsInFunc<'_> { type_of_val_as_spv_ptr_with_layout(input_ptr)?; if let Some(access_chain_data_inst_def) = self - .maybe_adjust_pointer_for_access( + .maybe_adjust_pointer_for_offset_or_access( input_ptr, input_ptr_addr_space, input_pointee_layout, - expected_pointee_type, + 0, + Some(expected_pointee_type), )? { to_spv_ptr_input_adjustments @@ -882,24 +811,38 @@ impl LiftToSpvPtrInstsInFunc<'_> { Ok(Transformed::Changed(replacement_data_inst_def)) } - /// If necessary, construct an `OpAccessChain` instruction to turn `ptr` - /// (pointing to a type with `pointee_layout`) into a pointer to `access_type` - /// (which can then be used with e.g. `OpLoad`/`OpStore`). + /// If necessary, construct an `OpAccessChain` instruction to offset `ptr` + /// (pointing to a type with `pointee_layout`) by `offset`, and (optionally) + /// turn it into a pointer to `access_type` (for e.g. `OpLoad`/`OpStore`). // - // FIXME(eddyb) customize errors, to tell apart Load/Store/ToSpvPtrInput. - fn maybe_adjust_pointer_for_access( + // FIXME(eddyb) customize errors, to tell apart Offset/Load/Store/ToSpvPtrInput. 
+ fn maybe_adjust_pointer_for_offset_or_access( &self, ptr: Value, addr_space: AddrSpace, mut pointee_layout: TypeLayout, - access_type: Type, + offset: i32, + access_type: Option, ) -> Result, LiftError> { let wk = self.lifter.wk; - let access_layout = self.lifter.layout_of(access_type)?; + let mk_access_chain = |access_chain_inputs: SmallVec<_>, final_pointee_type| { + if access_chain_inputs.len() > 1 { + Some(DataInstDef { + attrs: Default::default(), + form: self.lifter.cx.intern(DataInstFormDef { + kind: DataInstKind::SpvInst(wk.OpAccessChain.into()), + output_type: Some(self.lifter.spv_ptr_type(addr_space, final_pointee_type)), + }), + inputs: access_chain_inputs, + }) + } else { + None + } + }; + + let access_layout = access_type.map(|ty| self.lifter.layout_of(ty)).transpose()?; - // The access type might be merely a prefix of the pointee type, - // requiring injecting an extra `OpAccessChain` to "dig in". let mut access_chain_inputs: SmallVec<_> = [ptr].into_iter().collect(); if let TypeLayout::HandleArray(handle, _) = pointee_layout { @@ -907,97 +850,149 @@ impl LiftToSpvPtrInstsInFunc<'_> { .push(Value::Const(self.lifter.cx.intern(scalar::Const::from_u32(0)))); pointee_layout = TypeLayout::Handle(handle); } - match (pointee_layout, access_layout) { + let (mut pointee_layout, access_layout) = match (pointee_layout, access_layout) { (TypeLayout::HandleArray(..), _) => unreachable!(), // All the illegal cases are here to keep the rest tidier. 
- (_, TypeLayout::Handle(shapes::Handle::Buffer(..))) => { + (_, Some(TypeLayout::Handle(shapes::Handle::Buffer(..)))) => { return Err(LiftError(Diag::bug(["cannot access whole Buffer".into()]))); } - (_, TypeLayout::HandleArray(..)) => { + (_, Some(TypeLayout::HandleArray(..))) => { return Err(LiftError(Diag::bug(["cannot access whole HandleArray".into()]))); } - (_, TypeLayout::Concrete(access_layout)) + (_, Some(TypeLayout::Concrete(access_layout))) if access_layout.mem_layout.dyn_unit_stride.is_some() => { return Err(LiftError(Diag::bug(["cannot access unsized type".into()]))); } + (TypeLayout::Handle(_), Some(_)) if offset != 0 => { + return Err(LiftError(Diag::bug(["cannot offset Handles for access".into()]))); + } + (TypeLayout::Handle(_), None) => { + // FIXME(eddyb) this disallows even a noop offset on a handle pointer. + return Err(LiftError(Diag::bug(["cannot offset Handles".into()]))); + } (TypeLayout::Handle(shapes::Handle::Buffer(..)), _) => { - return Err(LiftError(Diag::bug(["cannot access into Buffer".into()]))); + return Err(LiftError(Diag::bug(["cannot offset/access into Buffer".into()]))); } - (TypeLayout::Handle(_), TypeLayout::Concrete(_)) => { + (TypeLayout::Handle(_), Some(TypeLayout::Concrete(_))) => { return Err(LiftError(Diag::bug(["cannot access Handle as memory".into()]))); } - (TypeLayout::Concrete(_), TypeLayout::Handle(_)) => { + (TypeLayout::Concrete(_), Some(TypeLayout::Handle(_))) => { return Err(LiftError(Diag::bug(["cannot access memory as Handle".into()]))); } ( TypeLayout::Handle(shapes::Handle::Opaque(pointee_handle_type)), - TypeLayout::Handle(shapes::Handle::Opaque(access_handle_type)), + Some(TypeLayout::Handle(shapes::Handle::Opaque(access_handle_type))), ) => { + assert_eq!(offset, 0); + if pointee_handle_type != access_handle_type { return Err(LiftError(Diag::bug([ "(opaque handle) pointer vs access type mismatch".into(), ]))); } + + return Ok(mk_access_chain(access_chain_inputs, pointee_handle_type)); } - 
(TypeLayout::Concrete(mut pointee_layout), TypeLayout::Concrete(access_layout)) => { - // FIXME(eddyb) deduplicate with access chain loop for Offset. - while pointee_layout.original_type != access_layout.original_type { - let idx = { - let offset_range = 0..access_layout.mem_layout.fixed_base.size; - let mut component_indices = - pointee_layout.components.find_components_containing(offset_range); - match (component_indices.next(), component_indices.next()) { - (None, _) => { - return Err(LiftError(Diag::bug([ - "accessed type not found in pointee type layout".into(), - ]))); - } - // FIXME(eddyb) obsolete this case entirely, - // by removing stores of ZSTs, and replacing - // loads of ZSTs with `OpUndef` constants. - (Some(_), Some(_)) => { - return Err(LiftError(Diag::bug([ - "ambiguity due to ZSTs in pointee type layout".into(), - ]))); - } - (Some(idx), None) => idx, - } - }; + (TypeLayout::Concrete(pointee_layout), Some(TypeLayout::Concrete(access_layout))) => { + (pointee_layout, Some(access_layout)) + } + (TypeLayout::Concrete(pointee_layout), None) => (pointee_layout, None), + }; - let idx_as_i32 = i32::try_from(idx).ok().ok_or_else(|| { - LiftError(Diag::bug([ - format!("{idx} not representable as a positive s32").into() - ])) + let mut offset = u32::try_from(offset) + .ok() + .ok_or_else(|| LiftError(Diag::bug(["negative offset".into()])))?; + + // FIXME(eddyb) deduplicate with access chain loop for Offset. + loop { + let done = offset == 0 + && access_layout.as_ref().map_or(true, |access_layout| { + pointee_layout.original_type == access_layout.original_type + }); + if done { + break; + } + + let idx = { + let min_component_size = match &access_layout { + Some(access_layout) => access_layout.mem_layout.fixed_base.size, + None => { + // HACK(eddyb) supporting ZSTs would be a pain because + // they can "fit" in weird ways, e.g. given 3 offsets + // A, B, C (before/between/after a pair of fields), + // `B..B` is included in both `A..B` and `B..C`. 
+ let allow_zst = false; + if allow_zst { 0 } else { 1 } + } + }; + + let offset_range = offset..offset.saturating_add(min_component_size); + let mut component_indices = + pointee_layout.components.find_components_containing(offset_range.clone()); + + let idx = component_indices + .next() + .or_else(|| { + // HACK(eddyb) when dealing with a lone ZST, the search can + // fail because it expects at least one byte, so we retry + // with an empty range instead. + // FIXME(eddyb) this can still fail if there's another + // component that ends where the ZST is, maybe we need + // some way to filter for specifically such a ZST. + if access_layout.is_none() && offset_range.len() == 1 { + component_indices = pointee_layout + .components + .find_components_containing(offset..offset); + component_indices.next() + } else { + None + } + }) + .ok_or_else(|| { + // FIXME(eddyb) this could include the chosen indices, + // and maybe the current type and/or layout. + LiftError(Diag::bug([format!( + "offsets {offset_range:?} not found in pointee type layout, \ + after {} access chain indices", + access_chain_inputs.len() - 1 + ) + .into()])) })?; - access_chain_inputs.push(Value::Const( - self.lifter.cx.intern(scalar::Const::from_u32(idx_as_i32 as u32)), - )); - pointee_layout = match &pointee_layout.components { - Components::Scalar => unreachable!(), - Components::Elements { elem, .. } => elem.clone(), - Components::Fields { layouts, .. 
} => layouts[idx].clone(), - }; + if component_indices.next().is_some() { + return Err(LiftError(Diag::bug([ + "ambiguity due to ZSTs in pointee type layout".into(), + ]))); + } + + idx + }; + + let idx_as_i32 = i32::try_from(idx).ok().ok_or_else(|| { + LiftError(Diag::bug([format!("{idx} not representable as a positive s32").into()])) + })?; + access_chain_inputs.push(Value::Const( + self.lifter.cx.intern(scalar::Const::from_u32(idx_as_i32 as u32)), + )); + + match &pointee_layout.components { + Components::Scalar => unreachable!(), + Components::Elements { stride, elem, .. } => { + offset %= stride.get(); + pointee_layout = elem.clone(); + } + Components::Fields { offsets, layouts } => { + offset -= offsets[idx]; + pointee_layout = layouts[idx].clone(); } } } - Ok(if access_chain_inputs.len() > 1 { - Some(DataInstDef { - attrs: Default::default(), - form: self.lifter.cx.intern(DataInstFormDef { - kind: DataInstKind::SpvInst(wk.OpAccessChain.into()), - output_type: Some(self.lifter.spv_ptr_type(addr_space, access_type)), - }), - inputs: access_chain_inputs, - }) - } else { - None - }) + Ok(mk_access_chain(access_chain_inputs, pointee_layout.original_type)) } /// Apply rewrites implied by `deferred_ptr_noops` to `values`. diff --git a/src/qptr/lower.rs b/src/qptr/lower.rs index dec482e6..a8ecab46 100644 --- a/src/qptr/lower.rs +++ b/src/qptr/lower.rs @@ -408,6 +408,20 @@ impl LowerFromSpvPtrInstsInFunc<'_> { _ => return Ok(Transformed::Unchanged), }; + // Map `ptr` to its base & offset, if it points to a `QPtrOp::Offset`. 
+ let ptr_to_base_ptr_and_offset = |ptr| match ptr { + Value::DataInstOutput(ptr_inst) => { + let ptr_inst_def = func.at(ptr_inst).def(); + match cx[ptr_inst_def.form].kind { + DataInstKind::QPtr(QPtrOp::Offset(ptr_offset)) => { + Some((ptr_inst_def.inputs[0], ptr_offset)) + } + _ => None, + } + } + _ => None, + }; + let replacement_kind_and_inputs = if spv_inst.opcode == wk.OpVariable { assert!(data_inst_def.inputs.len() <= 1); let (_, var_data_type) = @@ -429,14 +443,25 @@ impl LowerFromSpvPtrInstsInFunc<'_> { return Ok(Transformed::Unchanged); } assert_eq!(data_inst_def.inputs.len(), 1); - (QPtrOp::Load.into(), data_inst_def.inputs.clone()) + + let ptr = data_inst_def.inputs[0]; + + let (ptr, offset) = ptr_to_base_ptr_and_offset(ptr).unwrap_or((ptr, 0)); + + (QPtrOp::Load { offset }.into(), [ptr].into_iter().collect()) } else if spv_inst.opcode == wk.OpStore { // FIXME(eddyb) support memory operands somehow. if !spv_inst.imms.is_empty() { return Ok(Transformed::Unchanged); } assert_eq!(data_inst_def.inputs.len(), 2); - (QPtrOp::Store.into(), data_inst_def.inputs.clone()) + + let ptr = data_inst_def.inputs[0]; + let value = data_inst_def.inputs[1]; + + let (ptr, offset) = ptr_to_base_ptr_and_offset(ptr).unwrap_or((ptr, 0)); + + (QPtrOp::Store { offset }.into(), [ptr, value].into_iter().collect()) } else if spv_inst.opcode == wk.OpArrayLength { let field_idx = match spv_inst.imms[..] { [spv::Imm::Short(_, field_idx)] => field_idx, @@ -527,14 +552,27 @@ impl LowerFromSpvPtrInstsInFunc<'_> { self.lowerer.layout_of(base_pointee_type)? }; + let mut ptr = base_ptr; let mut steps = self.try_lower_access_chain(access_chain_base_layout, &data_inst_def.inputs[1..])?; + + // Fold a previous `Offset` into an initial offset step, where possible. 
+ if let Some(QPtrChainStep { op: QPtrOp::Offset(first_offset), dyn_idx: None }) = + steps.first_mut() + { + if let Some((ptr_base_ptr, ptr_offset)) = ptr_to_base_ptr_and_offset(ptr) { + if let Some(new_first_offset) = first_offset.checked_add(ptr_offset) { + ptr = ptr_base_ptr; + *first_offset = new_first_offset; + } + } + } + // HACK(eddyb) noop cases should probably not use any `DataInst`s at all, // but that would require the ability to replace all uses of a `Value`. let final_step = steps.pop().unwrap_or(QPtrChainStep { op: QPtrOp::Offset(0), dyn_idx: None }); - let mut ptr = base_ptr; for step in steps { let (kind, inputs) = step.into_data_inst_kind_and_inputs(ptr); let step_data_inst = func_at_data_inst.reborrow().data_insts.define( diff --git a/src/qptr/mod.rs b/src/qptr/mod.rs index a9e9c972..e31f82cc 100644 --- a/src/qptr/mod.rs +++ b/src/qptr/mod.rs @@ -193,16 +193,19 @@ pub enum QPtrOp { index_bounds: Option>, }, - /// Read a single value from a `QPtr` (`inputs[0]`). + /// Read a single value from a `QPtr` (`inputs[0]`) at `offset`. // // FIXME(eddyb) limit this to memory, and scalars, maybe vectors at most. - Load, + Load { + offset: i32, + }, - /// Write a single value (`inputs[1]`) to a `QPtr` (`inputs[0]`). + /// Write a single value (`inputs[1]`) to a `QPtr` (`inputs[0]`) at `offset`. // // FIXME(eddyb) limit this to memory, and scalars, maybe vectors at most. - Store, + Store { + offset: i32, + }, // - // FIXME(eddyb) implement more ops! at the very least copying! - // (and lowering could ignore pointercasts, I guess?) + // FIXME(eddyb) implement more ops (e.g. copies). } diff --git a/src/transform.rs b/src/transform.rs index 053ecb60..844751a0 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -720,8 +720,8 @@ impl InnerTransform for DataInstFormDef { | QPtrOp::BufferDynLen { .. } | QPtrOp::Offset(_) | QPtrOp::DynOffset { .. 
} - | QPtrOp::Load - | QPtrOp::Store => Transformed::Unchanged, + | QPtrOp::Load {..} + | QPtrOp::Store {..} => Transformed::Unchanged, }, DataInstKind::Scalar(_) | DataInstKind::Vector(_) diff --git a/src/visit.rs b/src/visit.rs index 5a54a74c..4cbde0b4 100644 --- a/src/visit.rs +++ b/src/visit.rs @@ -538,8 +538,8 @@ impl InnerVisit for DataInstFormDef { | QPtrOp::BufferDynLen { .. } | QPtrOp::Offset(_) | QPtrOp::DynOffset { .. } - | QPtrOp::Load - | QPtrOp::Store => {} + | QPtrOp::Load { .. } + | QPtrOp::Store { .. } => {} }, DataInstKind::Scalar(_) | DataInstKind::Vector(_) From 145051eb5913ba763ab50c75e46f0c369995456c Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Sun, 1 Oct 2023 17:03:02 +0300 Subject: [PATCH 11/22] qptr/lower: more aggressively strip `Offset(0)` and remove unused instructions. --- src/qptr/lower.rs | 193 +++++++++++++++++++++++++++++++++++++++------- 1 file changed, 163 insertions(+), 30 deletions(-) diff --git a/src/qptr/lower.rs b/src/qptr/lower.rs index a8ecab46..2fe70c01 100644 --- a/src/qptr/lower.rs +++ b/src/qptr/lower.rs @@ -7,12 +7,15 @@ use crate::func_at::FuncAtMut; use crate::qptr::{shapes, QPtrAttr, QPtrOp}; use crate::transform::{InnerInPlaceTransform, Transformed, Transformer}; use crate::{ - spv, AddrSpace, AttrSet, AttrSetDef, Const, ConstDef, ConstKind, Context, ControlNode, - ControlNodeKind, DataInst, DataInstDef, DataInstForm, DataInstFormDef, DataInstKind, Diag, - FuncDecl, GlobalVarDecl, OrdAssertEq, Type, TypeKind, TypeOrConst, Value, + spv, AddrSpace, AttrSetDef, Const, ConstDef, ConstKind, Context, ControlNode, ControlNodeKind, + DataInst, DataInstDef, DataInstForm, DataInstFormDef, DataInstKind, DeclDef, Diag, + EntityOrientedDenseMap, FuncDecl, GlobalVarDecl, OrdAssertEq, Type, TypeKind, TypeOrConst, + Value, }; +use rustc_hash::FxHashMap; use smallvec::SmallVec; use std::cell::Cell; +use std::mem; use std::num::NonZeroU32; use std::rc::Rc; @@ -141,7 +144,13 @@ impl<'a> LowerFromSpvPtrs<'a> { // 
separately - so `LowerFromSpvPtrInstsInFunc` will leave all value defs // (including replaced instructions!) with unchanged `OpTypePointer` // types, that only `EraseSpvPtrs`, later, replaces with `QPtr`. - LowerFromSpvPtrInstsInFunc { lowerer: self }.in_place_transform_func_decl(func_decl); + LowerFromSpvPtrInstsInFunc { + lowerer: self, + data_inst_use_counts: Default::default(), + remove_if_dead_inst_and_parent_block: Default::default(), + noop_offsets_to_base_ptr: Default::default(), + } + .in_place_transform_func_decl(func_decl); EraseSpvPtrs { lowerer: self }.in_place_transform_func_decl(func_decl); } @@ -241,6 +250,19 @@ impl Transformer for EraseSpvPtrs<'_> { struct LowerFromSpvPtrInstsInFunc<'a> { lowerer: &'a LowerFromSpvPtrs<'a>, + + // FIXME(eddyb) consider removing this and just do a full second traversal. + data_inst_use_counts: EntityOrientedDenseMap, + + // HACK(eddyb) this acts as a "queue" for `qptr`-producing instructions, + // which may end up dead because they're unused (either unused originally, + // in SPIR-V, or because of offset folding). + remove_if_dead_inst_and_parent_block: Vec<(DataInst, ControlNode)>, + + // FIXME(eddyb) this is redundant with a few other things and only here + // because it needs to be available from `transform_value`, which doesn't + // have access to a `FuncAt` to look up anything. + noop_offsets_to_base_ptr: FxHashMap, } /// One `QPtr`->`QPtr` step used in the lowering of `Op*AccessChain`. @@ -386,7 +408,7 @@ impl LowerFromSpvPtrInstsInFunc<'_> { } fn try_lower_data_inst_def( - &self, + &mut self, mut func_at_data_inst: FuncAtMut<'_, DataInst>, parent_block: ControlNode, ) -> Result, LowerError> { @@ -400,7 +422,7 @@ impl LowerFromSpvPtrInstsInFunc<'_> { // FIXME(eddyb) is this a good convention? 
let func = func_at_data_inst_frozen.at(()); - let mut attrs = data_inst_def.attrs; + let attrs = data_inst_def.attrs; let DataInstFormDef { ref kind, output_type } = cx[data_inst_def.form]; let spv_inst = match kind { @@ -408,18 +430,25 @@ impl LowerFromSpvPtrInstsInFunc<'_> { _ => return Ok(Transformed::Unchanged), }; - // Map `ptr` to its base & offset, if it points to a `QPtrOp::Offset`. - let ptr_to_base_ptr_and_offset = |ptr| match ptr { - Value::DataInstOutput(ptr_inst) => { + // Flatten `QPtrOp::Offset`s behind `ptr` into a base pointer and offset. + let flatten_offsets = |mut ptr| { + let mut offset = 0; + while let Value::DataInstOutput(ptr_inst) = ptr { let ptr_inst_def = func.at(ptr_inst).def(); match cx[ptr_inst_def.form].kind { DataInstKind::QPtr(QPtrOp::Offset(ptr_offset)) => { - Some((ptr_inst_def.inputs[0], ptr_offset)) + match ptr_offset.checked_add(offset) { + Some(combined_offset) => { + ptr = ptr_inst_def.inputs[0]; + offset = combined_offset; + } + None => break, + } } - _ => None, + _ => break, } } - _ => None, + (ptr, offset) }; let replacement_kind_and_inputs = if spv_inst.opcode == wk.OpVariable { @@ -446,7 +475,7 @@ impl LowerFromSpvPtrInstsInFunc<'_> { let ptr = data_inst_def.inputs[0]; - let (ptr, offset) = ptr_to_base_ptr_and_offset(ptr).unwrap_or((ptr, 0)); + let (ptr, offset) = flatten_offsets(ptr); (QPtrOp::Load { offset }.into(), [ptr].into_iter().collect()) } else if spv_inst.opcode == wk.OpStore { @@ -459,7 +488,7 @@ impl LowerFromSpvPtrInstsInFunc<'_> { let ptr = data_inst_def.inputs[0]; let value = data_inst_def.inputs[1]; - let (ptr, offset) = ptr_to_base_ptr_and_offset(ptr).unwrap_or((ptr, 0)); + let (ptr, offset) = flatten_offsets(ptr); (QPtrOp::Store { offset }.into(), [ptr, value].into_iter().collect()) } else if spv_inst.opcode == wk.OpArrayLength { @@ -560,11 +589,10 @@ impl LowerFromSpvPtrInstsInFunc<'_> { if let Some(QPtrChainStep { op: QPtrOp::Offset(first_offset), dyn_idx: None }) = steps.first_mut() { - if let 
Some((ptr_base_ptr, ptr_offset)) = ptr_to_base_ptr_and_offset(ptr) { - if let Some(new_first_offset) = first_offset.checked_add(ptr_offset) { - ptr = ptr_base_ptr; - *first_offset = new_first_offset; - } + let (ptr_base_ptr, ptr_offset) = flatten_offsets(ptr); + if let Some(new_first_offset) = first_offset.checked_add(ptr_offset) { + ptr = ptr_base_ptr; + *first_offset = new_first_offset; } } @@ -598,6 +626,11 @@ impl LowerFromSpvPtrInstsInFunc<'_> { match &mut func.control_nodes[parent_block].kind { ControlNodeKind::Block { insts } => { insts.insert_before(step_data_inst, data_inst, func.data_insts); + + // HACK(eddyb) this tracking is kind of ad-hoc but should + // easily cover everything we care about for now. + self.remove_if_dead_inst_and_parent_block + .push((step_data_inst, parent_block)); } _ => unreachable!(), } @@ -611,15 +644,9 @@ impl LowerFromSpvPtrInstsInFunc<'_> { if self.lowerer.as_spv_ptr_type(func.at(input).type_of(cx)).is_some() && self.lowerer.as_spv_ptr_type(output_type.unwrap()).is_some() { - // HACK(eddyb) noop cases should not use any `DataInst`s at all, - // but that would require the ability to replace all uses of a `Value`. - let noop_step = QPtrChainStep { op: QPtrOp::Offset(0), dyn_idx: None }; - - // HACK(eddyb) since we're not removing the `DataInst` entirely, - // at least get rid of its attributes to clearly mark it as synthetic. - attrs = AttrSet::default(); - - noop_step.into_data_inst_kind_and_inputs(input) + // HACK(eddyb) this will end up added to `noop_offsets_to_base_ptr`, + // which should replace all uses of this bitcast with its input. + (QPtrOp::Offset(0).into(), data_inst_def.inputs.clone()) } else { return Ok(Transformed::Unchanged); } @@ -696,9 +723,51 @@ impl LowerFromSpvPtrInstsInFunc<'_> { func_at_data_inst.def().attrs = cx.intern(attrs); } } + + // FIXME(eddyb) these are only this whacky because a `u32` is being + // encoded as `Option` for (dense) map entry reasons.
+ fn add_value_uses(&mut self, values: &[Value]) { + for &v in values { + if let Value::DataInstOutput(data_inst) = v { + let count = self.data_inst_use_counts.entry(data_inst); + *count = Some( + NonZeroU32::new(count.map_or(0, |c| c.get()).checked_add(1).unwrap()).unwrap(), + ); + } + } + } + fn remove_value_uses(&mut self, values: &[Value]) { + for &v in values { + if let Value::DataInstOutput(data_inst) = v { + let count = self.data_inst_use_counts.entry(data_inst); + *count = NonZeroU32::new(count.unwrap().get() - 1); + } + } + } } impl Transformer for LowerFromSpvPtrInstsInFunc<'_> { + // NOTE(eddyb) it's important that this only gets invoked on already lowered + // `Value`s, so we can rely on e.g. `noop_offsets_to_base_ptr` being filled. + fn transform_value_use(&mut self, v: &Value) -> Transformed { + let mut v = *v; + + let transformed = match v { + Value::DataInstOutput(inst) => self + .noop_offsets_to_base_ptr + .get(&inst) + .copied() + .map_or(Transformed::Unchanged, Transformed::Changed), + + _ => Transformed::Unchanged, + }; + + transformed.apply_to(&mut v); + self.add_value_uses(&[v]); + + transformed + } + // HACK(eddyb) while we want to transform `DataInstDef`s, we can't inject // adjacent instructions without access to the parent `ControlNodeKind::Block`, // and to fix this would likely require list nodes to carry some handle to @@ -711,14 +780,44 @@ impl Transformer for LowerFromSpvPtrInstsInFunc<'_> { &mut self, mut func_at_control_node: FuncAtMut<'_, ControlNode>, ) { - func_at_control_node.reborrow().inner_in_place_transform_with(self); - let control_node = func_at_control_node.position; if let ControlNodeKind::Block { insts } = func_at_control_node.reborrow().def().kind { let mut func_at_inst_iter = func_at_control_node.reborrow().at(insts).into_iter(); while let Some(mut func_at_inst) = func_at_inst_iter.next() { match self.try_lower_data_inst_def(func_at_inst.reborrow(), control_node) { Ok(Transformed::Changed(new_def)) => { + // 
HACK(eddyb) this tracking is kind of ad-hoc but should + // easily cover everything we care about for now. + if let DataInstKind::QPtr(op) = &self.lowerer.cx[new_def.form].kind { + match op { + QPtrOp::HandleArrayIndex + | QPtrOp::BufferData + | QPtrOp::BufferDynLen { .. } + | QPtrOp::Offset(_) + | QPtrOp::DynOffset { .. } => { + self.remove_if_dead_inst_and_parent_block + .push((func_at_inst.position, control_node)); + } + + QPtrOp::FuncLocalVar(_) + | QPtrOp::Load { .. } + | QPtrOp::Store { .. } => {} + } + + if let QPtrOp::Offset(0) = op { + let mut base_ptr = new_def.inputs[0]; + if let Value::DataInstOutput(base_ptr_inst) = base_ptr { + if let Some(&base_ptr_base_ptr) = + self.noop_offsets_to_base_ptr.get(&base_ptr_inst) + { + base_ptr = base_ptr_base_ptr; + } + } + self.noop_offsets_to_base_ptr + .insert(func_at_inst.position, base_ptr); + } + } + *func_at_inst.def() = new_def; } result @ (Ok(Transformed::Unchanged) | Err(_)) => { @@ -727,5 +826,39 @@ impl Transformer for LowerFromSpvPtrInstsInFunc<'_> { } } } + + // NOTE(eddyb) this is done last so that `transform_value_use` only sees + // the lowered `Value`s, not the original ones. + func_at_control_node.reborrow().inner_in_place_transform_with(self); + } + + fn in_place_transform_func_decl(&mut self, func_decl: &mut FuncDecl) { + func_decl.inner_in_place_transform_with(self); + + // Apply all `remove_if_dead_inst_and_parent_block` removals, that are truly unused. + if let DeclDef::Present(func_def_body) = &mut func_decl.def { + let remove_if_dead_inst_and_parent_block = + mem::take(&mut self.remove_if_dead_inst_and_parent_block); + // NOTE(eddyb) reverse order is important, as each removal can reduce + // use counts of an earlier definition, allowing further removal. 
+ for (inst, parent_block) in remove_if_dead_inst_and_parent_block.into_iter().rev() { + if self.data_inst_use_counts.get(inst).is_none() { + // HACK(eddyb) can't really use helpers like `FuncAtMut::def`, + // due to the need to borrow `control_nodes` and `data_insts` + // at the same time - perhaps some kind of `FuncAtMut` position + // types for "where a list is in a parent entity" could be used + // to make this more ergonomic, although the potential need for + // an actual list entity of its own, should be considered. + match &mut func_def_body.control_nodes[parent_block].kind { + ControlNodeKind::Block { insts } => { + insts.remove(inst, &mut func_def_body.data_insts); + } + _ => unreachable!(), + } + + self.remove_value_uses(&func_def_body.at(inst).def().inputs); + } + } + } } } From 53c6cece92291998926f7904664f4f1b615e4373 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Sun, 1 Oct 2023 17:03:29 +0300 Subject: [PATCH 12/22] spv/lift: abandon `Result` for allocating IDs and do it far more eagerly. 
--- src/qptr/lift.rs | 13 +- src/spv/lift.rs | 575 ++++++++++++++++++++++++----------------------- 2 files changed, 296 insertions(+), 292 deletions(-) diff --git a/src/qptr/lift.rs b/src/qptr/lift.rs index 58dd4f52..69fc221c 100644 --- a/src/qptr/lift.rs +++ b/src/qptr/lift.rs @@ -285,12 +285,15 @@ impl<'a> LiftToSpvPtrs<'a> { attrs: stride_attrs.unwrap_or_default(), kind: TypeKind::SpvInst { spv_inst: spv_opcode.into(), - type_and_const_inputs: [TypeOrConst::Type(element_type)] - .into_iter() - .chain(fixed_len.map(|len| { + type_and_const_inputs: [ + Some(TypeOrConst::Type(element_type)), + fixed_len.map(|len| { TypeOrConst::Const(self.cx.intern(scalar::Const::from_u32(len))) - })) - .collect(), + }), + ] + .into_iter() + .flatten() + .collect(), }, })) } diff --git a/src/spv/lift.rs b/src/spv/lift.rs index 88170778..3ceabd97 100644 --- a/src/spv/lift.rs +++ b/src/spv/lift.rs @@ -7,15 +7,15 @@ use crate::{ cfg, scalar, AddrSpace, Attr, AttrSet, Const, ConstDef, ConstKind, Context, ControlNode, ControlNodeKind, ControlNodeOutputDecl, ControlRegion, ControlRegionInputDecl, DataInst, DataInstDef, DataInstForm, DataInstFormDef, DataInstKind, DeclDef, EntityList, ExportKey, - Exportee, Func, FuncDecl, FuncParam, FxIndexMap, FxIndexSet, GlobalVar, GlobalVarDefBody, - Import, Module, ModuleDebugInfo, ModuleDialect, SelectionKind, Type, TypeDef, TypeKind, - TypeOrConst, Value, + Exportee, Func, FuncDecl, FuncDefBody, FuncParam, FxIndexMap, FxIndexSet, GlobalVar, + GlobalVarDefBody, Import, Module, ModuleDebugInfo, ModuleDialect, SelectionKind, Type, TypeDef, + TypeKind, TypeOrConst, Value, }; use rustc_hash::FxHashMap; use smallvec::SmallVec; use std::borrow::Cow; -use std::collections::{BTreeMap, BTreeSet}; -use std::num::NonZeroU32; +use std::collections::BTreeMap; +use std::num::NonZeroUsize; use std::path::Path; use std::{io, iter, mem, slice}; @@ -75,34 +75,29 @@ impl spv::ModuleDebugInfo { } } -impl FuncDecl { - fn spv_func_type(&self, cx: &Context) -> Type { - 
let wk = &spec::Spec::get().well_known; - - cx.intern(TypeDef { - attrs: AttrSet::default(), - kind: TypeKind::SpvInst { - spv_inst: wk.OpTypeFunction.into(), - type_and_const_inputs: iter::once(self.ret_type) - .chain(self.params.iter().map(|param| param.ty)) - .map(TypeOrConst::Type) - .collect(), - }, - }) - } -} - -struct NeedsIdsCollector<'a> { +struct IdAllocator<'a, AI: FnMut() -> spv::Id> { cx: &'a Context, module: &'a Module, - ext_inst_imports: BTreeSet<&'a str>, - debug_strings: BTreeSet<&'a str>, + /// ID allocation callback, kept as a closure (instead of having its state + /// be part of `IdAllocator`) to avoid misuse. + alloc_id: AI, + + ids: ModuleIds<'a>, - globals: FxIndexSet, data_inst_forms_seen: FxIndexSet, global_vars_seen: FxIndexSet, - funcs: FxIndexSet, +} + +#[derive(Default)] +struct ModuleIds<'a> { + ext_inst_imports: BTreeMap<&'a str, spv::Id>, + debug_strings: BTreeMap<&'a str, spv::Id>, + + // FIXME(eddyb) use `EntityOrientedDenseMap` here. + globals: FxIndexMap, + // FIXME(eddyb) use `EntityOrientedDenseMap` here. + funcs: FxIndexMap>, } #[derive(Copy, Clone, PartialEq, Eq, Hash)] @@ -111,13 +106,27 @@ enum Global { Const(Const), } -impl Visitor<'_> for NeedsIdsCollector<'_> { +// FIXME(eddyb) should this use ID ranges instead of `SmallVec<[spv::Id; 4]>`? +// FIXME(eddyb) this is inconsistently named with `FuncBodyLifting`. +struct FuncIds<'a> { + spv_func_ret_type: Type, + // FIXME(eddyb) should we even be interning an `OpTypeFunction` in `Context`? 
+ // (it's easier this way, but it could also be tracked in `ModuleIds`) + spv_func_type: Type, + + func_id: spv::Id, + param_ids: SmallVec<[spv::Id; 4]>, + + body: Option>, +} + +impl spv::Id> Visitor<'_> for IdAllocator<'_, AI> { fn visit_attr_set_use(&mut self, attrs: AttrSet) { self.visit_attr_set_def(&self.cx[attrs]); } fn visit_type_use(&mut self, ty: Type) { let global = Global::Type(ty); - if self.globals.contains(&global) { + if self.ids.globals.contains_key(&global) { return; } let ty_def = &self.cx[ty]; @@ -153,11 +162,11 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { } self.visit_type_def(ty_def); - self.globals.insert(global); + self.ids.globals.insert(global, (self.alloc_id)()); } fn visit_const_use(&mut self, ct: Const) { let global = Global::Const(ct); - if self.globals.contains(&global) { + if self.ids.globals.contains_key(&global) { return; } let ct_def = &self.cx[ct]; @@ -179,7 +188,7 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { | ConstKind::PtrToGlobalVar(_) | ConstKind::SpvInst { .. } => { self.visit_const_def(ct_def); - self.globals.insert(global); + self.ids.globals.insert(global, (self.alloc_id)()); } // HACK(eddyb) because this is an `OpString` and needs to go earlier @@ -197,7 +206,7 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { } ); - self.debug_strings.insert(&self.cx[s]); + self.ids.debug_strings.entry(&self.cx[s]).or_insert_with(&mut self.alloc_id); } } } @@ -213,24 +222,55 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { } } fn visit_func_use(&mut self, func: Func) { - if self.funcs.contains(&func) { + if self.ids.funcs.contains_key(&func) { return; } - // NOTE(eddyb) inserting first results in a different function ordering - // in the resulting module, but the order doesn't matter, and we need - // to avoid infinite recursion for recursive functions. - self.funcs.insert(func); - let func_decl = &self.module.funcs[func]; - // FIXME(eddyb) should this be cached in `self.funcs`? 
- self.visit_type_use(func_decl.spv_func_type(self.cx)); + + // Synthesize an `OpTypeFunction` type (that SPIR-T itself doesn't carry). + let wk = &spec::Spec::get().well_known; + let spv_func_ret_type = func_decl.ret_type; + let spv_func_type = self.cx.intern(TypeKind::SpvInst { + spv_inst: wk.OpTypeFunction.into(), + type_and_const_inputs: iter::once(spv_func_ret_type) + .chain(func_decl.params.iter().map(|param| param.ty)) + .map(TypeOrConst::Type) + .collect(), + }); + self.visit_type_use(spv_func_type); + + // NOTE(eddyb) inserting first produces a different function ordering + // overall in the final module, but the order doesn't matter, and we + // need to avoid infinite recursion for recursive functions. + self.ids.funcs.insert( + func, + FuncIds { + spv_func_ret_type, + spv_func_type, + func_id: (self.alloc_id)(), + param_ids: func_decl.params.iter().map(|_| (self.alloc_id)()).collect(), + body: None, + }, + ); + self.visit_func_decl(func_decl); + + // Handle the body last, to minimize recursion hazards (see comment above). + match &func_decl.def { + DeclDef::Imported(_) => {} + DeclDef::Present(func_def_body) => { + let func_body_lifting = FuncBodyLifting::from_func_def_body(self, func_def_body); + self.ids.funcs.get_mut(&func).unwrap().body = Some(func_body_lifting); + } + } } fn visit_spv_module_debug_info(&mut self, debug_info: &spv::ModuleDebugInfo) { for sources in debug_info.source_languages.values() { // The file operand of `OpSource` has to point to an `OpString`. - self.debug_strings.extend(sources.file_contents.keys().copied().map(|s| &self.cx[s])); + for &s in sources.file_contents.keys() { + self.ids.debug_strings.entry(&self.cx[s]).or_insert_with(&mut self.alloc_id); + } } } fn visit_attr(&mut self, attr: &Attr) { @@ -240,7 +280,10 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { | Attr::SpvAnnotation { .. } | Attr::SpvBitflagsOperand(_) => {} Attr::SpvDebugLine { file_path, .. 
} => { - self.debug_strings.insert(&self.cx[file_path.0]); + self.ids + .debug_strings + .entry(&self.cx[file_path.0]) + .or_insert_with(&mut self.alloc_id); } } attr.inner_visit_with(self); @@ -260,28 +303,18 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { | DataInstKind::SpvInst(_) => {} DataInstKind::SpvExtInst { ext_set, .. } => { - self.ext_inst_imports.insert(&self.cx[ext_set]); + self.ids + .ext_inst_imports + .entry(&self.cx[ext_set]) + .or_insert_with(&mut self.alloc_id); } } data_inst_form_def.inner_visit_with(self); } } -struct AllocatedIds<'a> { - ext_inst_imports: BTreeMap<&'a str, spv::Id>, - debug_strings: BTreeMap<&'a str, spv::Id>, - - // FIXME(eddyb) use `EntityOrientedDenseMap` here. - globals: FxIndexMap, - // FIXME(eddyb) use `EntityOrientedDenseMap` here. - funcs: FxIndexMap>, -} - -// FIXME(eddyb) should this use ID ranges instead of `SmallVec<[spv::Id; 4]>`? -struct FuncLifting<'a> { - func_id: spv::Id, - param_ids: SmallVec<[spv::Id; 4]>, - +// FIXME(eddyb) this is inconsistently named with `FuncIds`. +struct FuncBodyLifting<'a> { // FIXME(eddyb) use `EntityOrientedDenseMap` here. region_inputs_source: FxHashMap, // FIXME(eddyb) use `EntityOrientedDenseMap` here. 
@@ -381,42 +414,6 @@ enum Merge { }, } -impl<'a> NeedsIdsCollector<'a> { - fn alloc_ids( - self, - mut alloc_id: impl FnMut() -> Result, - ) -> Result, E> { - let Self { - cx, - module, - ext_inst_imports, - debug_strings, - globals, - data_inst_forms_seen: _, - global_vars_seen: _, - funcs, - } = self; - - Ok(AllocatedIds { - ext_inst_imports: ext_inst_imports - .into_iter() - .map(|name| Ok((name, alloc_id()?))) - .collect::>()?, - debug_strings: debug_strings - .into_iter() - .map(|s| Ok((s, alloc_id()?))) - .collect::>()?, - globals: globals.into_iter().map(|g| Ok((g, alloc_id()?))).collect::>()?, - funcs: funcs - .into_iter() - .map(|func| { - Ok((func, FuncLifting::from_func_decl(cx, &module.funcs[func], &mut alloc_id)?)) - }) - .collect::>()?, - }) - } -} - /// Helper type for deep traversal of the CFG (as a graph of [`CfgPoint`]s), which /// tracks the necessary context for navigating a [`ControlRegion`]/[`ControlNode`]. #[derive(Copy, Clone)] @@ -494,41 +491,37 @@ impl<'a, 'p> FuncAt<'a, CfgCursor<'p>> { impl<'a> FuncAt<'a, ControlRegion> { /// Traverse every [`CfgPoint`] (deeply) contained in this [`ControlRegion`], /// in reverse post-order (RPO), with `f` receiving each [`CfgPoint`] - /// in turn (wrapped in [`CfgCursor`], for further traversal flexibility), - /// and being able to stop iteration by returning `Err`. + /// in turn (wrapped in [`CfgCursor`], for further traversal flexibility). /// /// RPO iteration over a CFG provides certain guarantees, most importantly /// that SSA definitions are visited before any of their uses. 
- fn rev_post_order_try_for_each( - self, - mut f: impl FnMut(CfgCursor<'_>) -> Result<(), E>, - ) -> Result<(), E> { - self.rev_post_order_try_for_each_inner(&mut f, None) + fn rev_post_order_for_each(self, mut f: impl FnMut(CfgCursor<'_>)) { + self.rev_post_order_for_each_inner(&mut f, None); } - fn rev_post_order_try_for_each_inner( + fn rev_post_order_for_each_inner( self, - f: &mut impl FnMut(CfgCursor<'_>) -> Result<(), E>, + f: &mut impl FnMut(CfgCursor<'_>), parent: Option<&CfgCursor<'_, ControlParent>>, - ) -> Result<(), E> { + ) { let region = self.position; - f(CfgCursor { point: CfgPoint::RegionEntry(region), parent })?; + f(CfgCursor { point: CfgPoint::RegionEntry(region), parent }); for func_at_control_node in self.at_children() { - func_at_control_node.rev_post_order_try_for_each_inner( + func_at_control_node.rev_post_order_for_each_inner( f, &CfgCursor { point: ControlParent::Region(region), parent }, - )?; + ); } - f(CfgCursor { point: CfgPoint::RegionExit(region), parent }) + f(CfgCursor { point: CfgPoint::RegionExit(region), parent }); } } impl<'a> FuncAt<'a, ControlNode> { - fn rev_post_order_try_for_each_inner( + fn rev_post_order_for_each_inner( self, - f: &mut impl FnMut(CfgCursor<'_>) -> Result<(), E>, + f: &mut impl FnMut(CfgCursor<'_>), parent: &CfgCursor<'_, ControlParent>, - ) -> Result<(), E> { + ) { let child_regions: &[_] = match &self.def().kind { ControlNodeKind::Block { .. } => &[], ControlNodeKind::Select { cases, .. 
} => cases, @@ -537,39 +530,23 @@ impl<'a> FuncAt<'a, ControlNode> { let control_node = self.position; let parent = Some(parent); - f(CfgCursor { point: CfgPoint::ControlNodeEntry(control_node), parent })?; + f(CfgCursor { point: CfgPoint::ControlNodeEntry(control_node), parent }); for ®ion in child_regions { - self.at(region).rev_post_order_try_for_each_inner( + self.at(region).rev_post_order_for_each_inner( f, Some(&CfgCursor { point: ControlParent::ControlNode(control_node), parent }), - )?; + ); } - f(CfgCursor { point: CfgPoint::ControlNodeExit(control_node), parent }) + f(CfgCursor { point: CfgPoint::ControlNodeExit(control_node), parent }); } } -impl<'a> FuncLifting<'a> { - fn from_func_decl( - cx: &Context, - func_decl: &'a FuncDecl, - mut alloc_id: impl FnMut() -> Result, - ) -> Result { - let func_id = alloc_id()?; - let param_ids = func_decl.params.iter().map(|_| alloc_id()).collect::>()?; - - let func_def_body = match &func_decl.def { - DeclDef::Imported(_) => { - return Ok(Self { - func_id, - param_ids, - region_inputs_source: Default::default(), - data_inst_output_ids: Default::default(), - label_ids: Default::default(), - blocks: Default::default(), - }); - } - DeclDef::Present(def) => def, - }; +impl<'a> FuncBodyLifting<'a> { + fn from_func_def_body( + id_allocator: &mut IdAllocator<'_, impl FnMut() -> spv::Id>, + func_def_body: &'a FuncDefBody, + ) -> Self { + let cx = id_allocator.cx; let mut region_inputs_source = FxHashMap::default(); region_inputs_source.insert(func_def_body.body, RegionInputsSource::FuncParams); @@ -590,17 +567,15 @@ impl<'a> FuncLifting<'a> { .def() .inputs .iter() - .map(|&ControlRegionInputDecl { attrs, ty }| { - Ok(Phi { - attrs, - ty, + .map(|&ControlRegionInputDecl { attrs, ty }| Phi { + attrs, + ty, - result_id: alloc_id()?, - cases: FxIndexMap::default(), - default_value: None, - }) + result_id: (id_allocator.alloc_id)(), + cases: FxIndexMap::default(), + default_value: None, }) - .collect::>()? 
+ .collect() } } CfgPoint::RegionExit(_) => SmallVec::new(), @@ -624,17 +599,15 @@ impl<'a> FuncLifting<'a> { loop_body_inputs .iter() .enumerate() - .map(|(i, &ControlRegionInputDecl { attrs, ty })| { - Ok(Phi { - attrs, - ty, - - result_id: alloc_id()?, - cases: FxIndexMap::default(), - default_value: Some(initial_inputs[i]), - }) + .map(|(i, &ControlRegionInputDecl { attrs, ty })| Phi { + attrs, + ty, + + result_id: (id_allocator.alloc_id)(), + cases: FxIndexMap::default(), + default_value: Some(initial_inputs[i]), }) - .collect::>()? + .collect() } _ => SmallVec::new(), } @@ -644,17 +617,15 @@ impl<'a> FuncLifting<'a> { .def() .outputs .iter() - .map(|&ControlNodeOutputDecl { attrs, ty }| { - Ok(Phi { - attrs, - ty, - - result_id: alloc_id()?, - cases: FxIndexMap::default(), - default_value: None, - }) + .map(|&ControlNodeOutputDecl { attrs, ty }| Phi { + attrs, + ty, + + result_id: (id_allocator.alloc_id)(), + cases: FxIndexMap::default(), + default_value: None, }) - .collect::>()?, + .collect(), }; let insts = match point { @@ -836,14 +807,14 @@ impl<'a> FuncLifting<'a> { }; blocks.insert(point, BlockLifting { phis, insts, terminator }); - - Ok(()) }; match &func_def_body.unstructured_cfg { - None => func_def_body.at_body().rev_post_order_try_for_each(visit_cfg_point)?, + None => { + func_def_body.at_body().rev_post_order_for_each(visit_cfg_point); + } Some(cfg) => { for region in cfg.rev_post_order(func_def_body) { - func_def_body.at(region).rev_post_order_try_for_each(&mut visit_cfg_point)?; + func_def_body.at(region).rev_post_order_for_each(&mut visit_cfg_point); } } } @@ -986,31 +957,26 @@ impl<'a> FuncLifting<'a> { .filter(|&func_at_inst| cx[func_at_inst.def().form].output_type.is_some()) .map(|func_at_inst| func_at_inst.position); - Ok(Self { - func_id, - param_ids, + Self { region_inputs_source, data_inst_output_ids: all_insts_with_output - .map(|inst| Ok((inst, alloc_id()?))) - .collect::>()?, - label_ids: blocks - .keys() - .map(|&point| Ok((point, 
alloc_id()?))) - .collect::>()?, + .map(|inst| (inst, (id_allocator.alloc_id)())) + .collect(), + label_ids: blocks.keys().map(|&point| (point, (id_allocator.alloc_id)())).collect(), blocks, - }) + } } } -/// "Maybe-decorated "lazy" SPIR-V instruction, allowing separately emitting +/// Maybe-decorated "lazy" SPIR-V instruction, allowing separately emitting /// decorations from attributes, and the instruction itself, without eagerly /// allocating all the instructions. #[derive(Copy, Clone)] enum LazyInst<'a, 'b> { Global(Global), OpFunction { - func_id: spv::Id, func_decl: &'a FuncDecl, + func_ids: &'b FuncIds<'a>, }, OpFunctionParameter { param_id: spv::Id, @@ -1020,17 +986,17 @@ enum LazyInst<'a, 'b> { label_id: spv::Id, }, OpPhi { - parent_func: &'b FuncLifting<'a>, + parent_func_ids: &'b FuncIds<'a>, phi: &'b Phi, }, DataInst { - parent_func: &'b FuncLifting<'a>, + parent_func_ids: &'b FuncIds<'a>, result_id: Option, data_inst_def: &'a DataInstDef, }, Merge(Merge), Terminator { - parent_func: &'b FuncLifting<'a>, + parent_func_ids: &'b FuncIds<'a>, terminator: &'b Terminator<'a>, }, OpFunctionEnd, @@ -1040,7 +1006,7 @@ impl LazyInst<'_, '_> { fn result_id_attrs_and_import( self, module: &Module, - ids: &AllocatedIds<'_>, + ids: &ModuleIds<'_>, ) -> (Option, AttrSet, Option) { let cx = module.cx_ref(); @@ -1073,21 +1039,21 @@ impl LazyInst<'_, '_> { }; (Some(ids.globals[&global]), attrs, import) } - Self::OpFunction { func_id, func_decl } => { + Self::OpFunction { func_decl, func_ids } => { let import = match func_decl.def { DeclDef::Imported(import) => Some(import), DeclDef::Present(_) => None, }; - (Some(func_id), func_decl.attrs, import) + (Some(func_ids.func_id), func_decl.attrs, import) } Self::OpFunctionParameter { param_id, param } => (Some(param_id), param.attrs, None), Self::OpLabel { label_id } => (Some(label_id), AttrSet::default(), None), - Self::OpPhi { parent_func: _, phi } => (Some(phi.result_id), phi.attrs, None), - Self::DataInst { parent_func: 
_, result_id, data_inst_def } => { + Self::OpPhi { parent_func_ids: _, phi } => (Some(phi.result_id), phi.attrs, None), + Self::DataInst { parent_func_ids: _, result_id, data_inst_def } => { (result_id, data_inst_def.attrs, None) } Self::Merge(_) => (None, AttrSet::default(), None), - Self::Terminator { parent_func: _, terminator } => (None, terminator.attrs, None), + Self::Terminator { parent_func_ids: _, terminator } => (None, terminator.attrs, None), Self::OpFunctionEnd => (None, AttrSet::default(), None), } } @@ -1095,12 +1061,12 @@ impl LazyInst<'_, '_> { fn to_inst_and_attrs( self, module: &Module, - ids: &AllocatedIds<'_>, + ids: &ModuleIds<'_>, ) -> (spv::InstWithIds, AttrSet) { let wk = &spec::Spec::get().well_known; let cx = module.cx_ref(); - let value_to_id = |parent_func: &FuncLifting<'_>, v| match v { + let value_to_id = |parent_func_ids: &FuncIds<'_>, v| match v { Value::Const(ct) => match cx[ct].kind { ConstKind::SpvStringLiteralForExtInst(s) => ids.debug_strings[&cx[s]], @@ -1108,23 +1074,30 @@ impl LazyInst<'_, '_> { }, Value::ControlRegionInput { region, input_idx } => { let input_idx = usize::try_from(input_idx).unwrap(); - match parent_func.region_inputs_source.get(®ion) { - Some(RegionInputsSource::FuncParams) => parent_func.param_ids[input_idx], + let parent_func_body_lifting = parent_func_ids.body.as_ref().unwrap(); + match parent_func_body_lifting.region_inputs_source.get(®ion) { + Some(RegionInputsSource::FuncParams) => parent_func_ids.param_ids[input_idx], Some(&RegionInputsSource::LoopHeaderPhis(loop_node)) => { - parent_func.blocks[&CfgPoint::ControlNodeEntry(loop_node)].phis[input_idx] + parent_func_body_lifting.blocks[&CfgPoint::ControlNodeEntry(loop_node)].phis + [input_idx] .result_id } None => { - parent_func.blocks[&CfgPoint::RegionEntry(region)].phis[input_idx].result_id + parent_func_body_lifting.blocks[&CfgPoint::RegionEntry(region)].phis + [input_idx] + .result_id } } } Value::ControlNodeOutput { control_node, output_idx } => 
{ - parent_func.blocks[&CfgPoint::ControlNodeExit(control_node)].phis - [usize::try_from(output_idx).unwrap()] + parent_func_ids.body.as_ref().unwrap().blocks + [&CfgPoint::ControlNodeExit(control_node)] + .phis[usize::try_from(output_idx).unwrap()] .result_id } - Value::DataInstOutput(inst) => parent_func.data_inst_output_ids[&inst], + Value::DataInstOutput(inst) => { + parent_func_ids.body.as_ref().unwrap().data_inst_output_ids[&inst] + } }; let (result_id, attrs, _) = self.result_id_attrs_and_import(module, ids); @@ -1233,7 +1206,7 @@ impl LazyInst<'_, '_> { } } }, - Self::OpFunction { func_id: _, func_decl } => { + Self::OpFunction { func_decl: _, func_ids } => { // FIXME(eddyb) make this less of a search and more of a // lookup by splitting attrs into key and value parts. let func_ctrl = cx[attrs] @@ -1254,10 +1227,9 @@ impl LazyInst<'_, '_> { opcode: wk.OpFunction, imms: iter::once(spv::Imm::Short(wk.FunctionControl, func_ctrl)).collect(), }, - result_type_id: Some(ids.globals[&Global::Type(func_decl.ret_type)]), + result_type_id: Some(ids.globals[&Global::Type(func_ids.spv_func_ret_type)]), result_id, - ids: iter::once(ids.globals[&Global::Type(func_decl.spv_func_type(cx))]) - .collect(), + ids: iter::once(ids.globals[&Global::Type(func_ids.spv_func_type)]).collect(), } } Self::OpFunctionParameter { param_id: _, param } => spv::InstWithIds { @@ -1272,7 +1244,7 @@ impl LazyInst<'_, '_> { result_id, ids: [].into_iter().collect(), }, - Self::OpPhi { parent_func, phi } => spv::InstWithIds { + Self::OpPhi { parent_func_ids, phi } => spv::InstWithIds { without_ids: wk.OpPhi.into(), result_type_id: Some(ids.globals[&Global::Type(phi.ty)]), result_id: Some(phi.result_id), @@ -1280,11 +1252,14 @@ impl LazyInst<'_, '_> { .cases .iter() .flat_map(|(&source_point, &v)| { - [value_to_id(parent_func, v), parent_func.label_ids[&source_point]] + [ + value_to_id(parent_func_ids, v), + parent_func_ids.body.as_ref().unwrap().label_ids[&source_point], + ] }) .collect(), }, - 
Self::DataInst { parent_func, result_id: _, data_inst_def } => { + Self::DataInst { parent_func_ids, result_id: _, data_inst_def } => { let DataInstFormDef { kind, output_type } = &cx[data_inst_def.form]; let (inst, extra_initial_id_operand) = match spv::Inst::from_canonical_data_inst_kind(kind).ok_or(kind) { @@ -1316,7 +1291,9 @@ impl LazyInst<'_, '_> { result_id, ids: extra_initial_id_operand .into_iter() - .chain(data_inst_def.inputs.iter().map(|&v| value_to_id(parent_func, v))) + .chain( + data_inst_def.inputs.iter().map(|&v| value_to_id(parent_func_ids, v)), + ) .collect(), } } @@ -1341,12 +1318,18 @@ impl LazyInst<'_, '_> { result_id: None, ids: [merge_label_id, continue_label_id].into_iter().collect(), }, - Self::Terminator { parent_func, terminator } => { + Self::Terminator { parent_func_ids, terminator } => { + let parent_func_body_lifting = parent_func_ids.body.as_ref().unwrap(); let mut ids: SmallVec<[_; 4]> = terminator .inputs .iter() - .map(|&v| value_to_id(parent_func, v)) - .chain(terminator.targets.iter().map(|&target| parent_func.label_ids[&target])) + .map(|&v| value_to_id(parent_func_ids, v)) + .chain( + terminator + .targets + .iter() + .map(|&target| parent_func_body_lifting.label_ids[&target]), + ) .collect(); // FIXME(eddyb) move some of this to `spv::canonical`. @@ -1418,19 +1401,6 @@ impl Module { } }; - // Collect uses scattered throughout the module, that require def IDs. - let mut needs_ids_collector = NeedsIdsCollector { - cx: &cx, - module: self, - ext_inst_imports: BTreeSet::new(), - debug_strings: BTreeSet::new(), - globals: FxIndexSet::default(), - data_inst_forms_seen: FxIndexSet::default(), - global_vars_seen: FxIndexSet::default(), - funcs: FxIndexSet::default(), - }; - needs_ids_collector.visit_module(self); - // Because `GlobalVar`s are given IDs by the `Const`s that point to them // (i.e. `ConstKind::PtrToGlobalVar`), any `GlobalVar`s in other positions // require extra care to ensure the ID-giving `Const` is visited. 
@@ -1443,84 +1413,115 @@ impl Module { }); Global::Const(ptr_to_global_var) }; - for &gv in &needs_ids_collector.global_vars_seen { - needs_ids_collector.globals.insert(global_var_to_id_giving_global(gv)); - } - // IDs can be allocated once we have the full sets needing them, whether - // sorted by contents, or ordered by the first occurence in the module. - let mut id_bound = NonZeroU32::MIN; - let ids = needs_ids_collector.alloc_ids(|| { - let id = id_bound; + // Collect uses scattered throughout the module, allocating IDs for them. + let (ids, id_bound) = { + let mut id_bound = NonZeroUsize::MIN; + let mut id_allocator = IdAllocator { + cx: &cx, + module: self, + alloc_id: || { + let id = id_bound; + id_bound = + id_bound.checked_add(1).expect("overflowing `usize` should be impossible"); + + // NOTE(eddyb) `MAX` is just a placeholder - the check for overflows + // is done below, after all IDs that may be allocated, have been + // (this is in order to not need this closure to return a `Result`). + id.try_into().unwrap_or(spv::Id::new(u32::MAX).unwrap()) + }, + ids: ModuleIds::default(), + data_inst_forms_seen: FxIndexSet::default(), + global_vars_seen: FxIndexSet::default(), + }; + id_allocator.visit_module(self); + + // See comment on `global_var_to_id_giving_global` for why this is here. + for &gv in &id_allocator.global_vars_seen { + id_allocator + .ids + .globals + .entry(global_var_to_id_giving_global(gv)) + .or_insert_with(&mut id_allocator.alloc_id); + } - match id_bound.checked_add(1) { - Some(new_bound) => { - id_bound = new_bound; - Ok(id) - } - None => Err(io::Error::new( + let ids = id_allocator.ids; + + let id_bound = spv::Id::try_from(id_bound).ok().ok_or_else(|| { + io::Error::new( io::ErrorKind::InvalidData, "ID bound of SPIR-V module doesn't fit in 32 bits", - )), - } - })?; + ) + })?; + + (ids, id_bound) + }; // HACK(eddyb) allow `move` closures below to reference `cx` or `ids` // without causing unwanted moves out of them. 
let (cx, ids) = (&*cx, &ids); let global_and_func_insts = ids.globals.keys().copied().map(LazyInst::Global).chain( - ids.funcs.iter().flat_map(|(&func, func_lifting)| { + ids.funcs.iter().flat_map(|(&func, func_ids)| { let func_decl = &self.funcs[func]; - let func_def_body = match &func_decl.def { - DeclDef::Imported(_) => None, - DeclDef::Present(def) => Some(def), + let body_with_lifting = match (&func_decl.def, &func_ids.body) { + (DeclDef::Imported(_), None) => None, + (DeclDef::Present(def), Some(func_body_lifting)) => { + Some((def, func_body_lifting)) + } + _ => unreachable!(), }; - iter::once(LazyInst::OpFunction { func_id: func_lifting.func_id, func_decl }) - .chain(func_lifting.param_ids.iter().zip(&func_decl.params).map( - |(¶m_id, param)| LazyInst::OpFunctionParameter { param_id, param }, - )) - .chain(func_lifting.blocks.iter().flat_map(move |(point, block)| { + let param_insts = + func_ids.param_ids.iter().zip(&func_decl.params).map(|(¶m_id, param)| { + LazyInst::OpFunctionParameter { param_id, param } + }); + let body_insts = body_with_lifting.map(|(func_def_body, func_body_lifting)| { + func_body_lifting.blocks.iter().flat_map(move |(point, block)| { let BlockLifting { phis, insts, terminator } = block; - iter::once(LazyInst::OpLabel { label_id: func_lifting.label_ids[point] }) - .chain( - phis.iter() - .map(|phi| LazyInst::OpPhi { parent_func: func_lifting, phi }), - ) - .chain( - insts - .iter() - .copied() - .flat_map(move |insts| func_def_body.unwrap().at(insts)) - .map(move |func_at_inst| { - let data_inst_def = func_at_inst.def(); - LazyInst::DataInst { - parent_func: func_lifting, - result_id: cx[data_inst_def.form].output_type.map( - |_| { - func_lifting.data_inst_output_ids - [&func_at_inst.position] - }, - ), - data_inst_def, - } - }), - ) - .chain(terminator.merge.map(|merge| { - LazyInst::Merge(match merge { - Merge::Selection(merge) => { - Merge::Selection(func_lifting.label_ids[&merge]) + iter::once(LazyInst::OpLabel { + label_id: 
func_body_lifting.label_ids[point], + }) + .chain( + phis.iter() + .map(|phi| LazyInst::OpPhi { parent_func_ids: func_ids, phi }), + ) + .chain( + insts + .iter() + .copied() + .flat_map(move |insts| func_def_body.at(insts)) + .map(move |func_at_inst| { + let data_inst_def = func_at_inst.def(); + LazyInst::DataInst { + parent_func_ids: func_ids, + result_id: cx[data_inst_def.form].output_type.map(|_| { + func_body_lifting.data_inst_output_ids + [&func_at_inst.position] + }), + data_inst_def, } - Merge::Loop { loop_merge, loop_continue } => Merge::Loop { - loop_merge: func_lifting.label_ids[&loop_merge], - loop_continue: func_lifting.label_ids[&loop_continue], - }, - }) - })) - .chain([LazyInst::Terminator { parent_func: func_lifting, terminator }]) - })) + }), + ) + .chain(terminator.merge.map(|merge| { + LazyInst::Merge(match merge { + Merge::Selection(merge) => { + Merge::Selection(func_body_lifting.label_ids[&merge]) + } + Merge::Loop { loop_merge, loop_continue } => Merge::Loop { + loop_merge: func_body_lifting.label_ids[&loop_merge], + loop_continue: func_body_lifting.label_ids[&loop_continue], + }, + }) + })) + .chain([LazyInst::Terminator { parent_func_ids: func_ids, terminator }]) + }) + }); + + iter::once(LazyInst::OpFunction { func_decl, func_ids }) + .chain(param_insts) + .chain(body_insts.into_iter().flatten()) .chain([LazyInst::OpFunctionEnd]) }), ); From 22c8b95f7702c756e5c4c714fd3e51b301aa4c9e Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Sun, 1 Oct 2023 17:04:00 +0300 Subject: [PATCH 13/22] spv/lower: don't assign IDs ahead of time. 
--- src/spv/lower.rs | 226 +++++++++++++++++++++-------------------------- 1 file changed, 99 insertions(+), 127 deletions(-) diff --git a/src/spv/lower.rs b/src/spv/lower.rs index b7a6077c..f8b75b33 100644 --- a/src/spv/lower.rs +++ b/src/spv/lower.rs @@ -845,30 +845,12 @@ impl Module { return Err(invalid("OpFunction without matching OpFunctionEnd")); } - // HACK(eddyb) `OpNop` is useful for defining `DataInst`s before they're - // actually lowered (to be able to refer to their outputs `Value`s). - let mut cached_op_nop_form = None; - let mut get_op_nop_form = || { - *cached_op_nop_form.get_or_insert_with(|| { - cx.intern(DataInstFormDef { - kind: DataInstKind::SpvInst(wk.OpNop.into()), - output_type: None, - }) - }) - }; - // Process function bodies, having seen the whole module. for func_body in pending_func_bodies { let FuncBody { func_id, func, insts: raw_insts } = func_body; let func_decl = &mut module.funcs[func]; - #[derive(Copy, Clone)] - enum LocalIdDef { - Value(Type, Value), - BlockLabel(ControlRegion), - } - #[derive(PartialEq, Eq, Hash)] struct PhiKey { source_block_id: spv::Id, @@ -881,111 +863,68 @@ impl Module { phi_count: u32, } - // Index IDs declared within the function, first. - let mut local_id_defs = FxIndexMap::default(); - // `OpPhi`s are also collected here, to assign them per-edge. + // Gather `OpLabel`s and `OpPhi`s early (so they can be random-accessed). let mut phi_to_values = FxIndexMap::>::default(); let mut block_details = FxIndexMap::::default(); let mut has_blocks = false; { - let mut next_param_idx = 0u32; for raw_inst in &raw_insts { let IntraFuncInst { without_ids: spv::Inst { opcode, ref imms }, result_id, - result_type, .. 
} = *raw_inst; - if let Some(id) = result_id { - let local_id_def = if opcode == wk.OpFunctionParameter { - let idx = next_param_idx; - next_param_idx = idx.checked_add(1).unwrap(); + if opcode == wk.OpFunctionParameter { + continue; + } - let body = match &func_decl.def { - // `LocalIdDef`s not needed for declarations. - DeclDef::Imported(_) => continue, + let is_entry_block = !has_blocks; + has_blocks = true; - DeclDef::Present(def) => def.body, - }; - LocalIdDef::Value( - result_type.unwrap(), - Value::ControlRegionInput { region: body, input_idx: idx }, - ) - } else { - let is_entry_block = !has_blocks; - has_blocks = true; + let func_def_body = match &mut func_decl.def { + // Error will be emitted later, below. + DeclDef::Imported(_) => continue, + DeclDef::Present(def) => def, + }; - let func_def_body = match &mut func_decl.def { - // Error will be emitted later, below. - DeclDef::Imported(_) => continue, - DeclDef::Present(def) => def, - }; + if opcode == wk.OpLabel { + let block = if is_entry_block { + // A `ControlRegion` was defined earlier, + // to be able to create the `FuncDefBody`. + func_def_body.body + } else { + func_def_body.control_regions.define(&cx, ControlRegionDef::default()) + }; + block_details.insert( + block, + BlockDetails { label_id: result_id.unwrap(), phi_count: 0 }, + ); + } else if opcode == wk.OpPhi { + let (_, block_details) = match block_details.last_mut() { + Some(entry) => entry, + // Error will be emitted later, below. + None => continue, + }; - if opcode == wk.OpLabel { - let block = if is_entry_block { - // A `ControlRegion` was defined earlier, - // to be able to create the `FuncDefBody`. 
- func_def_body.body - } else { - func_def_body - .control_regions - .define(&cx, ControlRegionDef::default()) - }; - block_details - .insert(block, BlockDetails { label_id: id, phi_count: 0 }); - LocalIdDef::BlockLabel(block) - } else if opcode == wk.OpPhi { - let (¤t_block, block_details) = match block_details.last_mut() - { - Some(entry) => entry, - // Error will be emitted later, below. - None => continue, - }; - - let phi_idx = block_details.phi_count; - block_details.phi_count = phi_idx.checked_add(1).unwrap(); - - assert!(imms.is_empty()); - // FIXME(eddyb) use `array_chunks` when that's stable. - for value_and_source_block_id in raw_inst.ids.chunks(2) { - let &[value_id, source_block_id]: &[_; 2] = - value_and_source_block_id.try_into().unwrap(); - - phi_to_values - .entry(PhiKey { - source_block_id, - target_block_id: block_details.label_id, - target_phi_idx: phi_idx, - }) - .or_default() - .push(value_id); - } + let phi_idx = block_details.phi_count; + block_details.phi_count = phi_idx.checked_add(1).unwrap(); - LocalIdDef::Value( - result_type.unwrap(), - Value::ControlRegionInput { - region: current_block, - input_idx: phi_idx, - }, - ) - } else { - // HACK(eddyb) can't get a `DataInst` without - // defining it (as a dummy) first. - let inst = func_def_body.data_insts.define( - &cx, - DataInstDef { - attrs: AttrSet::default(), - // FIXME(eddyb) cache this form locally. - form: get_op_nop_form(), - inputs: [].into_iter().collect(), - } - .into(), - ); - LocalIdDef::Value(result_type.unwrap(), Value::DataInstOutput(inst)) - } - }; - local_id_defs.insert(id, local_id_def); + assert!(imms.is_empty()); + // FIXME(eddyb) use `array_chunks` when that's stable. 
+ for value_and_source_block_id in raw_inst.ids.chunks(2) { + let &[value_id, source_block_id]: &[_; 2] = + value_and_source_block_id.try_into().unwrap(); + + phi_to_values + .entry(PhiKey { + source_block_id, + target_block_id: block_details.label_id, + target_phi_idx: phi_idx, + }) + .or_default() + .push(value_id); + } } } } @@ -1018,6 +957,21 @@ impl Module { None }; + #[derive(Copy, Clone)] + enum LocalIdDef { + Value(Type, Value), + BlockLabel(ControlRegion), + } + + let mut local_id_defs = FxIndexMap::default(); + + // Labels can be forward-referenced, so always have them present. + local_id_defs.extend( + block_details + .iter() + .map(|(®ion, details)| (details.label_id, LocalIdDef::BlockLabel(region))), + ); + let mut current_block_control_region_and_details = None; for (raw_inst_idx, raw_inst) in raw_insts.iter().enumerate() { let lookahead_raw_inst = @@ -1072,6 +1026,9 @@ impl Module { "unsupported use of {} outside `OpExtInst`", id_def.descr(&cx), ))), + // FIXME(eddyb) scan the rest of the function for any + // instructions returning this ID, to report an invalid + // forward reference (use before def). 
None => local_id_defs .get(&id) .copied() @@ -1091,11 +1048,17 @@ impl Module { let ty = result_type.unwrap(); params.push(FuncParam { attrs, ty }); if let Some(func_def_body) = &mut func_def_body { - func_def_body - .at_mut_body() - .def() - .inputs - .push(ControlRegionInputDecl { attrs, ty }); + let body_inputs = &mut func_def_body.at_mut_body().def().inputs; + let input_idx = u32::try_from(body_inputs.len()).unwrap(); + body_inputs.push(ControlRegionInputDecl { attrs, ty }); + + local_id_defs.insert( + result_id.unwrap(), + LocalIdDef::Value( + ty, + Value::ControlRegionInput { region: func_def_body.body, input_idx }, + ), + ); } continue; } @@ -1109,8 +1072,7 @@ impl Module { return Err(invalid("block lacks terminator instruction")); } - // A `ControlRegion` (using an empty `Block` `ControlNode` - // as its sole child) was defined earlier, + // An empty `ControlRegion` was defined earlier, // to be able to have an entry in `local_id_defs`. let control_region = match local_id_defs[&result_id.unwrap()] { LocalIdDef::BlockLabel(control_region) => control_region, @@ -1317,9 +1279,24 @@ impl Module { )); } + let ty = result_type.unwrap(); + + let input_idx = + u32::try_from(current_block_control_region_def.inputs.len()).unwrap(); current_block_control_region_def .inputs - .push(ControlRegionInputDecl { attrs, ty: result_type.unwrap() }); + .push(ControlRegionInputDecl { attrs, ty }); + + local_id_defs.insert( + result_id.unwrap(), + LocalIdDef::Value( + ty, + Value::ControlRegionInput { + region: current_block_control_region, + input_idx, + }, + ), + ); } else if [wk.OpSelectionMerge, wk.OpLoopMerge].contains(&opcode) { let is_second_to_last_in_block = lookahead_raw_inst(2) .map_or(true, |next_raw_inst| { @@ -1448,19 +1425,14 @@ impl Module { }) .collect::>()?, }; - let inst = match result_id { - Some(id) => match local_id_defs[&id] { - LocalIdDef::Value(_, Value::DataInstOutput(inst)) => { - // A dummy was defined earlier, to be able to - // have an entry in 
`local_id_defs`. - func_def_body.data_insts[inst] = data_inst_def.into(); + let inst = func_def_body.data_insts.define(&cx, data_inst_def.into()); - inst - } - _ => unreachable!(), - }, - None => func_def_body.data_insts.define(&cx, data_inst_def.into()), - }; + if let Some(result_id) = result_id { + local_id_defs.insert( + result_id, + LocalIdDef::Value(result_type.unwrap(), Value::DataInstOutput(inst)), + ); + } let current_block_control_node = current_block_control_region_def .children From d710313e283869035fab455aae19ff832ade0201 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Sun, 1 Oct 2023 17:04:31 +0300 Subject: [PATCH 14/22] spv: disaggregate by-value `OpTypeStruct`/`OpTypeArray` inputs/outputs. --- README.md | 2 +- src/cfg.rs | 3 +- src/func_at.rs | 4 +- src/lib.rs | 60 ++- src/print/mod.rs | 239 ++++++++--- src/qptr/analyze.rs | 142 ++++-- src/qptr/layout.rs | 2 +- src/qptr/lift.rs | 150 ++++--- src/qptr/lower.rs | 98 +++-- src/qptr/mod.rs | 2 +- src/spv/canonical.rs | 47 +- src/spv/lift.rs | 705 ++++++++++++++++++++++-------- src/spv/lower.rs | 997 +++++++++++++++++++++++++++++++++---------- src/spv/mod.rs | 313 +++++++++++++- src/spv/spec.rs | 1 + src/transform.rs | 54 ++- src/visit.rs | 36 +- 17 files changed, 2231 insertions(+), 624 deletions(-) diff --git a/README.md b/README.md index 7ca18c9e..582da461 100644 --- a/README.md +++ b/README.md @@ -137,7 +137,7 @@ fn main() -> @location(0) i32 { #[spv.Decoration.Location(Location: 0)] global_var GV0 in spv.StorageClass.Output: s32 -func F0() -> spv.OpTypeVoid { +func F0() { loop(v0: s32 <- 1s32, v1: s32 <- 1s32) { v2 = s.lt(v1, 10s32): bool (v3: s32, v4: s32) = if v2 { diff --git a/src/cfg.rs b/src/cfg.rs index 6a09cff8..183ca8f9 100644 --- a/src/cfg.rs +++ b/src/cfg.rs @@ -51,7 +51,8 @@ pub enum ControlInstKind { /// necessary preconditions for reaching this point, are never met. Unreachable, - /// Leave the current function, optionally returning a value. 
+ /// Leave the current function, returning some number of [`Value`]s, as per + /// the function's signature (`ret_types` in [`FuncDecl`](crate::FuncDecl)). Return, /// Leave the current invocation, similar to returning from every function diff --git a/src/func_at.rs b/src/func_at.rs index e72bd13d..eb1fbc3b 100644 --- a/src/func_at.rs +++ b/src/func_at.rs @@ -116,7 +116,9 @@ impl FuncAt<'_, Value> { Value::ControlNodeOutput { control_node, output_idx } => { self.at(control_node).def().outputs[output_idx as usize].ty } - Value::DataInstOutput(inst) => cx[self.at(inst).def().form].output_type.unwrap(), + Value::DataInstOutput { inst, output_idx } => { + cx[self.at(inst).def().form].output_types[output_idx as usize] + } } } } diff --git a/src/lib.rs b/src/lib.rs index 88aeeca8..31291009 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -495,10 +495,12 @@ pub enum TypeKind { // separately in e.g. `ControlRegionInputDecl`, might be a better approach? QPtr, + // FIXME(eddyb) consider wrapping all of these in an `Rc` like `ConstKind`. SpvInst { spv_inst: spv::Inst, // FIXME(eddyb) find a better name. type_and_const_inputs: SmallVec<[TypeOrConst; 2]>, + value_lowering: spv::ValueLowering, }, /// The type of a [`ConstKind::SpvStringLiteralForExtInst`] constant, i.e. @@ -543,10 +545,12 @@ impl Type { } } -/// Interned handle for a [`ConstDef`](crate::ConstDef) (a constant value). +/// Interned handle for a [`ConstDef`](crate::ConstDef) (a constant [`Value`](crate::Value)). pub use context::Const; -/// Definition for a [`Const`]: a constant value. +/// Definition for a [`Const`]: a constant [`Value`]. +/// +/// See [`Value`] docs for limitations on the types of values, including [`Const`]s. // // FIXME(eddyb) maybe special-case some basic consts like integer literals. 
#[derive(PartialEq, Eq, Hash)] @@ -704,7 +708,7 @@ pub use context::Func; pub struct FuncDecl { pub attrs: AttrSet, - pub ret_type: Type, + pub ret_types: SmallVec<[Type; 2]>, pub params: SmallVec<[FuncParam; 2]>, @@ -976,7 +980,19 @@ pub use context::DataInstForm; pub struct DataInstFormDef { pub kind: DataInstKind, - pub output_type: Option, + /// Types for all the outputs of instructions with this "form". + /// + /// That is, `output_types[i]` is the type of the [`Value::DataInstOutput`] + /// with `output_idx == i` (see also [`Value`] documentation). + /// + /// Most instructions have `0` or `1` outputs, with these notable exceptions: + /// * calls which return multiple values + /// * SPIR-V instructions which originally produced SPIR-V "aggregates" + /// (`OpTypeStruct`/`OpTypeArray`) before [`spv::lower`] decomposed them + /// * in the general case, [`spv::InstLowering`] tracks original types + // + // FIXME(eddyb) change the inline size of this to fit most instructions. + pub output_types: SmallVec<[Type; 2]>, } #[derive(Clone, PartialEq, Eq, Hash, derive_more::From)] @@ -1002,13 +1018,39 @@ pub enum DataInstKind { QPtr(qptr::QPtrOp), // FIXME(eddyb) should this have `#[from]`? - SpvInst(spv::Inst), + SpvInst(spv::Inst, spv::InstLowering), SpvExtInst { ext_set: InternedStr, inst: u32, + lowering: spv::InstLowering, }, } +/// Use of a value, either constant or defined earlier in the same function. +/// +/// Each `Value` can only have one of these types: +/// * [`scalar`] (`bool`, integer, and floating-point), i.e. [`TypeKind::Scalar`] +/// * vectors (small array of [`scalar`]s) +/// * these are *not* traditional SIMD vectors, but more a form of "compression" +/// (i.e. vector ops often applying the equivalent scalar op per-component), +/// and sometimes also mandated by specs (e.g. 
some Vulkan `BuiltIn` types) +/// * matrices (small array of vectors) +/// * less fundamental than vectors, may be treated like arrays in the future +/// * pointers and by-value (but still opaque) resource handles +/// * SPIR-V has both opaque resource handles that behave much like pointers, +/// even physical ones (e.g. ray-tracing `OpTypeAccelerationStructureKHR`s), +/// and others that are only loaded from memory just before using them as +/// operands (e.g. images/samplers), and such mismatches in indirection may +/// result in SPIR-T making further distinctions here in the future +/// +/// Notably, "aggregate" types (SPIR-V `OpTypeStruct`/`OpTypeArray`) are excluded, +/// so they have to be (recursively) disaggregated into their constituents, and +/// passed around as separate `Value`s (see also [`DataInstFormDef`] docs). +/// * SPIR-V inherited "by-value aggregates" from LLVM, which supports them under +/// the name "FCA" ("first-class aggregates"), but other IRs (and LLVM passes) +/// avoid them because of their (negative) impact on analyses and transforms, +/// with their main vestigial purpose being to encode multiple return values +/// from functions, which can be done more directly in other IRs (and SPIR-T) #[derive(Copy, Clone, PartialEq, Eq)] pub enum Value { Const(Const), @@ -1032,6 +1074,10 @@ pub enum Value { output_idx: u32, }, - /// The output value of a [`DataInst`]. - DataInstOutput(DataInst), + /// One of the outputs produced by a [`DataInst`], with its type given by + /// `cx[data_insts[inst].form].output_types[output_idx]`. + DataInstOutput { + inst: DataInst, + output_idx: u32, + }, } diff --git a/src/print/mod.rs b/src/print/mod.rs index 7618cbe2..553f21ef 100644 --- a/src/print/mod.rs +++ b/src/print/mod.rs @@ -199,7 +199,7 @@ enum Use { // should just use `Value` and assert it's never `Const`? 
ControlRegionInput { region: ControlRegion, input_idx: u32 }, ControlNodeOutput { control_node: ControlNode, output_idx: u32 }, - DataInstOutput(DataInst), + DataInstOutput { inst: DataInst, output_idx: u32 }, // NOTE(eddyb) these overlap somewhat with other cases, but they're always // generated, even when there is no "use", for `multiversion` alignment. @@ -218,7 +218,7 @@ impl From for Use { Value::ControlNodeOutput { control_node, output_idx } => { Use::ControlNodeOutput { control_node, output_idx } } - Value::DataInstOutput(inst) => Use::DataInstOutput(inst), + Value::DataInstOutput { inst, output_idx } => Use::DataInstOutput { inst, output_idx }, } } } @@ -237,7 +237,7 @@ impl Use { Self::ControlRegionInput { .. } | Self::ControlNodeOutput { .. } - | Self::DataInstOutput(_) => ("", "v"), + | Self::DataInstOutput { .. } => ("", "v"), Self::AlignmentAnchorForControlRegion(_) | Self::AlignmentAnchorForControlNode(_) @@ -540,7 +540,7 @@ impl<'a> Visitor<'a> for Plan<'a> { let wk = &spv::spec::Spec::get().well_known; match &self.cx[gv_decl.type_of_ptr_to].kind { - TypeKind::SpvInst { spv_inst, type_and_const_inputs } + TypeKind::SpvInst { spv_inst, type_and_const_inputs, .. } if spv_inst.opcode == wk.OpTypePointer => { match type_and_const_inputs[..] { @@ -759,7 +759,7 @@ impl<'a> Printer<'a> { if let Use::ControlRegionLabel(_) | Use::ControlRegionInput { .. } | Use::ControlNodeOutput { .. } - | Use::DataInstOutput(_) = use_kind + | Use::DataInstOutput { .. } = use_kind { return (use_kind, UseStyle::Inline); } @@ -793,7 +793,7 @@ impl<'a> Printer<'a> { Use::ControlRegionLabel(_) | Use::ControlRegionInput { .. } | Use::ControlNodeOutput { .. } - | Use::DataInstOutput(_) + | Use::DataInstOutput { .. } | Use::AlignmentAnchorForControlRegion(_) | Use::AlignmentAnchorForControlNode(_) | Use::AlignmentAnchorForDataInst(_) => unreachable!(), @@ -846,7 +846,7 @@ impl<'a> Printer<'a> { Use::ControlRegionLabel(_) | Use::ControlRegionInput { .. 
} | Use::ControlNodeOutput { .. } - | Use::DataInstOutput(_) + | Use::DataInstOutput { .. } | Use::AlignmentAnchorForControlRegion(_) | Use::AlignmentAnchorForControlNode(_) | Use::AlignmentAnchorForDataInst(_) => { @@ -869,7 +869,7 @@ impl<'a> Printer<'a> { | Use::ControlRegionLabel(_) | Use::ControlRegionInput { .. } | Use::ControlNodeOutput { .. } - | Use::DataInstOutput(_) + | Use::DataInstOutput { .. } | Use::AlignmentAnchorForControlRegion(_) | Use::AlignmentAnchorForControlNode(_) | Use::AlignmentAnchorForDataInst(_) => { @@ -950,15 +950,28 @@ impl<'a> Printer<'a> { if let ControlNodeKind::Block { insts } = *kind { for func_at_inst in func_def_body.at(insts) { - define( - Use::AlignmentAnchorForDataInst(func_at_inst.position), - None, - ); + let inst = func_at_inst.position; + define(Use::AlignmentAnchorForDataInst(inst), None); + let inst_def = func_at_inst.def(); - if cx[inst_def.form].output_type.is_some() { + let inst_form_def = &cx[inst_def.form]; + let attrs_for_output = Some(inst_def.attrs).filter(|_| { + inst_form_def.output_types.len() == 1 + && match &inst_form_def.kind { + DataInstKind::SpvInst(_, lowering) + | DataInstKind::SpvExtInst { lowering, .. } => { + lowering.disaggregated_output.is_none() + } + _ => true, + } + }); + for (i, _) in inst_form_def.output_types.iter().enumerate() { define( - Use::DataInstOutput(func_at_inst.position), - Some(inst_def.attrs), + Use::DataInstOutput { + inst, + output_idx: i.try_into().unwrap(), + }, + attrs_for_output, ); } } @@ -1014,7 +1027,9 @@ impl<'a> Printer<'a> { Use::ControlRegionInput { .. } | Use::ControlNodeOutput { .. } - | Use::DataInstOutput(_) => (&mut value_counter, use_styles.get_mut(&use_kind)), + | Use::DataInstOutput { .. } => { + (&mut value_counter, use_styles.get_mut(&use_kind)) + } Use::AlignmentAnchorForControlRegion(_) | Use::AlignmentAnchorForControlNode(_) @@ -1560,7 +1575,7 @@ impl Use { Self::ControlRegionLabel(_) | Self::ControlRegionInput { .. } | Self::ControlNodeOutput { .. 
} - | Self::DataInstOutput(_) => "_".into(), + | Self::DataInstOutput { .. } => "_".into(), Self::AlignmentAnchorForControlRegion(_) | Self::AlignmentAnchorForControlNode(_) @@ -2378,15 +2393,16 @@ impl Print for TypeDef { // FIXME(eddyb) should this be shortened to `qtr`? TypeKind::QPtr => printer.declarative_keyword_style().apply("qptr").into(), - TypeKind::SpvInst { spv_inst, type_and_const_inputs } => printer.pretty_spv_inst( - printer.spv_op_style(), - spv_inst.opcode, - &spv_inst.imms, - type_and_const_inputs.iter().map(|&ty_or_ct| match ty_or_ct { - TypeOrConst::Type(ty) => ty.print(printer), - TypeOrConst::Const(ct) => ct.print(printer), - }), - ), + TypeKind::SpvInst { spv_inst, type_and_const_inputs, .. } => printer + .pretty_spv_inst( + printer.spv_op_style(), + spv_inst.opcode, + &spv_inst.imms, + type_and_const_inputs.iter().map(|&ty_or_ct| match ty_or_ct { + TypeOrConst::Type(ty) => ty.print(printer), + TypeOrConst::Const(ct) => ct.print(printer), + }), + ), TypeKind::SpvStringLiteralForExtInst => pretty::Fragment::new([ printer.error_style().apply("type_of").into(), "(".into(), @@ -2648,7 +2664,7 @@ impl Print for GlobalVarDecl { printer.pretty_type_ascription_suffix(ty) } }, - TypeKind::SpvInst { spv_inst, type_and_const_inputs } + TypeKind::SpvInst { spv_inst, type_and_const_inputs, .. } if spv_inst.opcode == wk.OpTypePointer => { match type_and_const_inputs[..] 
{ @@ -2704,7 +2720,19 @@ impl Print for AddrSpace { impl Print for FuncDecl { type Output = AttrsAndDef; fn print(&self, printer: &Printer<'_>) -> AttrsAndDef { - let Self { attrs, ret_type, params, def } = self; + let Self { attrs, ret_types, params, def } = self; + + let sig_ret = if !ret_types.is_empty() { + let mut ret_types = ret_types.iter().map(|ty| ty.print(printer)); + let ret_type = if ret_types.len() == 1 { + ret_types.next().unwrap() + } else { + pretty::join_comma_sep("(", ret_types, ")") + }; + pretty::Fragment::new([" -> ".into(), ret_type]) + } else { + pretty::Fragment::default() + }; let sig = pretty::Fragment::new([ pretty::join_comma_sep( @@ -2722,8 +2750,7 @@ impl Print for FuncDecl { }), ")", ), - " -> ".into(), - ret_type.print(printer), + sig_ret, ]); let def_without_name = match def { @@ -3009,6 +3036,21 @@ impl Print for ControlNodeOutputDecl { } } +impl Print for spv::ReaggregatedIdOperand<'_, Value> { + type Output = pretty::Fragment; + fn print(&self, printer: &Printer<'_>) -> pretty::Fragment { + match *self { + Self::Direct(v) => v.print(printer), + // FIXME(eddyb) should this be recursive? it's not on the + // output side, and we largely don't care about nesting. 
+ Self::Aggregate { ty, leaves } => pretty::Fragment::new([ + pretty::join_comma_sep("(", leaves.iter().map(|v| v.print(printer)), ")"), + printer.pretty_type_ascription_suffix(ty), + ]), + } + } +} + impl Print for FuncAt<'_, DataInst> { type Output = pretty::Fragment; fn print(&self, printer: &Printer<'_>) -> pretty::Fragment { @@ -3016,12 +3058,38 @@ impl Print for FuncAt<'_, DataInst> { let attrs = attrs.print(printer); - let DataInstFormDef { kind, output_type } = &printer.cx[*form]; + let DataInstFormDef { kind, output_types } = &printer.cx[*form]; - let mut output_use_to_print_as_lhs = - output_type.map(|_| Use::DataInstOutput(self.position)); + // NOTE(eddyb) the LHS types and the ascryption type don't have to line up, + // all the edge cases (likely only single-leaf aggregates) are handled + // by comparing the types being printed (and showing both if not redundant). + let mut output_uses_for_lhs = if !output_types.is_empty() { + Some(output_types.iter().enumerate().map(|(output_idx, &output_type)| { + ( + Use::DataInstOutput { + inst: self.position, + output_idx: output_idx.try_into().unwrap(), + }, + output_type, + ) + })) + } else { + None + }; - let mut output_type_to_print = *output_type; + let mut output_type_for_ascription_suffix = match kind { + DataInstKind::Scalar(_) + | DataInstKind::Vector(_) + | DataInstKind::FuncCall(_) + | DataInstKind::QPtr(_) => None, + DataInstKind::SpvInst(_, lowering) | DataInstKind::SpvExtInst { lowering, .. } => { + lowering.disaggregated_output + } + } + .or_else(|| match output_types[..] { + [ty] => Some(ty), + _ => None, + }); // FIXME(eddyb) should this be a method on `scalar::Op` instead? 
let print_scalar = |op: scalar::Op| { @@ -3035,7 +3103,7 @@ impl Print for FuncAt<'_, DataInst> { ]) }; - let def_without_type = match kind { + let def_without_types = match kind { &DataInstKind::Scalar(op) => pretty::Fragment::new([ print_scalar(op), pretty::join_comma_sep("(", inputs.iter().map(|v| v.print(printer)), ")"), @@ -3230,38 +3298,45 @@ impl Print for FuncAt<'_, DataInst> { ]) } - DataInstKind::SpvInst(inst) => printer.pretty_spv_inst( + DataInstKind::SpvInst(inst, lowering) => printer.pretty_spv_inst( printer.spv_op_style(), inst.opcode, &inst.imms, - inputs.iter().map(|v| v.print(printer)), + lowering.reaggreate_inputs(inputs).map(|o| o.print(printer)), ), - &DataInstKind::SpvExtInst { ext_set, inst } => { + DataInstKind::SpvExtInst { ext_set, inst, lowering } => { let spv_spec = spv::spec::Spec::get(); let wk = &spv_spec.well_known; + // HACK(eddyb) prevent accidentally using non-reaggregated `inputs`. + let inputs = lowering.reaggreate_inputs(inputs); + // HACK(eddyb) hide `OpTypeVoid` types, as they're effectively // the default, and not meaningful *even if* the resulting // value is "used" in a kind of "untyped token" way. - output_type_to_print = output_type_to_print.filter(|&ty| { - let is_void = match &printer.cx[ty].kind { - TypeKind::SpvInst { spv_inst, .. } => spv_inst.opcode == wk.OpTypeVoid, - _ => false, - }; - !is_void - }); + output_type_for_ascription_suffix = + output_type_for_ascription_suffix.filter(|&ty| { + let is_void = match &printer.cx[ty].kind { + TypeKind::SpvInst { spv_inst, .. } => spv_inst.opcode == wk.OpTypeVoid, + _ => false, + }; + !is_void + }); // HACK(eddyb) only keep around untyped outputs if they're used. 
- if output_type_to_print.is_none() { - output_use_to_print_as_lhs = output_use_to_print_as_lhs.filter(|output_use| { + if output_type_for_ascription_suffix.is_none() { + output_uses_for_lhs = output_uses_for_lhs.filter(|output_uses_with_types| { + assert_eq!(output_uses_with_types.len(), 1); + let (output_use, _output_type) = + output_uses_with_types.clone().next().unwrap(); printer .use_styles - .get(output_use) + .get(&output_use) .is_some_and(|style| !matches!(style, UseStyle::Inline)) }); } // FIXME(eddyb) this may get expensive, cache it? - let ext_set_name = &printer.cx[ext_set]; + let ext_set_name = &printer.cx[*ext_set]; let lowercase_ext_set_name = ext_set_name.to_ascii_lowercase(); let (ext_set_alias, known_inst_desc) = (spv_spec .get_ext_inst_set_by_lowercase_name(&lowercase_ext_set_name)) @@ -3271,7 +3346,7 @@ impl Print for FuncAt<'_, DataInst> { .map_or((&None, None), |ext_inst_set| { // FIXME(eddyb) check that these aliases are unique // across the entire output before using them! - (&ext_inst_set.short_alias, ext_inst_set.instructions.get(&inst)) + (&ext_inst_set.short_alias, ext_inst_set.instructions.get(inst)) }); // FIXME(eddyb) extract and separate out the version? 
@@ -3289,8 +3364,8 @@ impl Print for FuncAt<'_, DataInst> { Str(&'a str), U32(u32), } - let pseudo_imm_from_value = |v: Value| { - if let Value::Const(ct) = v { + let pseudo_imm_from_input = |v: spv::ReaggregatedIdOperand<'_, Value>| { + if let spv::ReaggregatedIdOperand::Direct(Value::Const(ct)) = v { match &printer.cx[ct].kind { ConstKind::Undef | ConstKind::Vector(_) @@ -3311,10 +3386,8 @@ impl Print for FuncAt<'_, DataInst> { }; let debuginfo_with_pseudo_imm_inputs: Option> = known_inst_desc - .filter(|inst_desc| { - inst_desc.is_debuginfo && output_use_to_print_as_lhs.is_none() - }) - .and_then(|_| inputs.iter().copied().map(pseudo_imm_from_value).collect()); + .filter(|inst_desc| inst_desc.is_debuginfo && output_uses_for_lhs.is_none()) + .and_then(|_| inputs.clone().map(pseudo_imm_from_input).collect()); let printing_debuginfo_as_comment = debuginfo_with_pseudo_imm_inputs.is_some(); let [spv_base_style, string_literal_style, numeric_literal_style] = @@ -3394,9 +3467,9 @@ impl Print for FuncAt<'_, DataInst> { } else { pretty::join_comma_sep( "(", - inputs.iter().zip(operand_names).map(|(&input, name)| { + inputs.zip(operand_names).map(|(input, name)| { // HACK(eddyb) no need to wrap strings in `OpString(...)`. 
- let printed_input = match pseudo_imm_from_value(input) { + let printed_input = match pseudo_imm_from_input(input) { Some(PseudoImm::Str(s)) => printer.pretty_string_literal(s), _ => input.print(printer), }; @@ -3426,8 +3499,8 @@ impl Print for FuncAt<'_, DataInst> { }; let def_without_name = pretty::Fragment::new([ - def_without_type, - output_type_to_print + def_without_types, + output_type_for_ascription_suffix .map(|ty| printer.pretty_type_ascription_suffix(ty)) .unwrap_or_default(), ]); @@ -3438,11 +3511,33 @@ impl Print for FuncAt<'_, DataInst> { def_without_name, ]); + let outputs_lhs = output_uses_for_lhs.map(|output_uses_with_types| { + // NOTE(eddyb) adding a type to a single output on the LHS can only + // be needed when *a different type* was shown via type ascription. + if output_uses_with_types.len() == 1 { + let (output_use, output_type) = output_uses_with_types.clone().next().unwrap(); + let needs_lhs_type = + output_type_for_ascription_suffix.is_some_and(|ty| output_type != ty); + if !needs_lhs_type { + return output_use.print_as_def(printer); + } + } + + pretty::join_comma_sep( + "(", + output_uses_with_types.map(|(output_use, output_type)| { + pretty::Fragment::new([ + output_use.print_as_def(printer), + printer.pretty_type_ascription_suffix(output_type), + ]) + }), + ")", + ) + }); + AttrsAndDef { attrs, def_without_name }.insert_name_before_def( - output_use_to_print_as_lhs - .map(|output_use| { - pretty::Fragment::new([output_use.print_as_def(printer), " = ".into()]) - }) + outputs_lhs + .map(|outputs_lhs| pretty::Fragment::new([outputs_lhs, " = ".into()])) .unwrap_or_default(), ) } @@ -3482,10 +3577,18 @@ impl Print for cfg::ControlInst { cfg::ControlInstKind::Return => { // FIXME(eddyb) use `targets.is_empty()` when that is stabilized. assert!(targets.len() == 0); - match inputs[..] 
{ - [] => kw("return"), - [v] => pretty::Fragment::new([kw("return"), " ".into(), v.print(printer)]), - _ => unreachable!(), + if inputs.is_empty() { + kw("return") + } else { + let inputs = match inputs[..] { + [v] => v.print(printer), + _ => pretty::join_comma_sep( + "(", + inputs.iter().map(|v| v.print(printer)), + ")", + ), + }; + pretty::Fragment::new([kw("return"), " ".into(), inputs]) } } cfg::ControlInstKind::ExitInvocation(cfg::ExitInvocationKind::SpvInst(spv::Inst { diff --git a/src/qptr/analyze.rs b/src/qptr/analyze.rs index 47160ad4..d95517a3 100644 --- a/src/qptr/analyze.rs +++ b/src/qptr/analyze.rs @@ -805,7 +805,7 @@ impl<'a> InferUsage<'a> { DeclDef::Imported(_) => continue, }; - for (v, usage) in usage_or_err_attrs_to_attach { + for (v, mut usage) in usage_or_err_attrs_to_attach { let attrs = match v { Value::Const(_) => unreachable!(), Value::ControlRegionInput { region, input_idx } => { @@ -817,8 +817,45 @@ impl<'a> InferUsage<'a> { [output_idx as usize] .attrs } - Value::DataInstOutput(data_inst) => { - &mut func_def_body.at_mut(data_inst).def().attrs + Value::DataInstOutput { inst, output_idx } => { + let data_inst_def = func_def_body.at_mut(inst).def(); + + // HACK(eddyb) there are no legal multiple-output + // instructions, where one of the outputs is a + // pointer, and there are no per-output attributes, + // so guard against misunderstandings here. 
+ if output_idx != 0 + || self.cx[data_inst_def.form].output_types.len() != 1 + { + usage = Err(AnalysisError(match usage { + Ok(usage) => { + let attr: AttrSet = self.cx.intern(AttrSetDef { + attrs: [Attr::QPtr(QPtrAttr::Usage(OrdAssertEq( + usage, + )))] + .into(), + }); + Diag::bug([ + format!( + "cannot attach attribute to \ + output #{output_idx} of \ + multi-output instruction:\n" + ) + .into(), + attr.into(), + ]) + } + Err(AnalysisError(mut diag)) => { + diag.message.insert( + 0, + format!("output #{output_idx}: ").into(), + ); + diag + } + })); + } + + &mut data_inst_def.attrs } }; match usage { @@ -887,13 +924,28 @@ impl<'a> InferUsage<'a> { let mut all_data_insts = CollectAllDataInsts::default(); func_def_body.inner_visit_with(&mut all_data_insts); - let mut data_inst_output_usages = FxHashMap::default(); + let mut data_inst_to_per_output_usage: FxHashMap<_, SmallVec<[Option<_>; 2]>> = + FxHashMap::default(); for insts in all_data_insts.0.into_iter().rev() { for func_at_inst in func_def_body.at(insts).into_iter().rev() { let data_inst = func_at_inst.position; let data_inst_def = func_at_inst.def(); let data_inst_form_def = &cx[data_inst_def.form]; - let output_usage = data_inst_output_usages.remove(&data_inst).flatten(); + + // FIXME(eddyb) should remaining `Some`s in `per_output_usage` + // be attached to the instruction, after all the handling below? + let mut per_output_usage = + data_inst_to_per_output_usage.remove(&data_inst).unwrap_or_default(); + + // HACK(eddyb) this may be a bit wasteful, but it avoids + // complicating accessing `per_output_usage` below, and + // most instructions should only have at most two outputs. 
+ { + let expected = data_inst_form_def.output_types.len(); + if per_output_usage.len() < expected { + per_output_usage.extend((per_output_usage.len()..expected).map(|_| None)); + } + } let mut generate_usage = |this: &mut Self, ptr: Value, new_usage| { let slot = match ptr { @@ -906,8 +958,11 @@ impl<'a> InferUsage<'a> { // or actually support by adding the usage attribute // in the same manner (if it makes sense to do so). _ => { + // FIXME(eddyb) this output may not even exist, + // there should be a different way to have a + // `Diag` get attached to a whole `DataInst`. usage_or_err_attrs_to_attach.push(( - Value::DataInstOutput(data_inst), + Value::DataInstOutput { inst: data_inst, output_idx: 0 }, Err(AnalysisError(Diag::bug([ "unsupported pointer constant `".into(), ct.into(), @@ -930,8 +985,13 @@ impl<'a> InferUsage<'a> { )); return; } - Value::DataInstOutput(ptr_inst) => { - data_inst_output_usages.entry(ptr_inst).or_default() + Value::DataInstOutput { inst: ptr_inst, output_idx } => { + let i = output_idx as usize; + let slots = data_inst_to_per_output_usage.entry(ptr_inst).or_default(); + if i >= slots.len() { + slots.extend((slots.len()..=i).map(|_| None)); + } + &mut slots[i] } }; *slot = Some(match slot.take() { @@ -958,33 +1018,48 @@ impl<'a> InferUsage<'a> { } } FuncInferUsageState::InProgress => { + // FIXME(eddyb) this output may not even exist, + // there should be a different way to have a + // `Diag` get attached to a whole `DataInst`. 
usage_or_err_attrs_to_attach.push(( - Value::DataInstOutput(data_inst), + Value::DataInstOutput { inst: data_inst, output_idx: 0 }, Err(AnalysisError(Diag::bug([ "unsupported recursive call".into() ]))), )); } }; - if data_inst_form_def.output_type.map_or(false, is_qptr) { - if let Some(usage) = output_usage { - usage_or_err_attrs_to_attach - .push((Value::DataInstOutput(data_inst), usage)); + for (i, &ty) in data_inst_form_def.output_types.iter().enumerate() { + if is_qptr(ty) { + if let Some(usage) = per_output_usage[i].take() { + usage_or_err_attrs_to_attach.push(( + Value::DataInstOutput { + inst: data_inst, + output_idx: i.try_into().unwrap(), + }, + usage, + )); + } } } } DataInstKind::QPtr(QPtrOp::FuncLocalVar(_)) => { - if let Some(usage) = output_usage { - usage_or_err_attrs_to_attach - .push((Value::DataInstOutput(data_inst), usage)); + assert_eq!(per_output_usage.len(), 1); + if let Some(usage) = per_output_usage[0].take() { + usage_or_err_attrs_to_attach.push(( + Value::DataInstOutput { inst: data_inst, output_idx: 0 }, + usage, + )); } } DataInstKind::QPtr(QPtrOp::HandleArrayIndex) => { + assert_eq!(per_output_usage.len(), 1); generate_usage( self, data_inst_def.inputs[0], - output_usage + per_output_usage[0] + .take() .unwrap_or_else(|| { Err(AnalysisError(Diag::bug([ "HandleArrayIndex: unknown element".into(), @@ -999,10 +1074,12 @@ impl<'a> InferUsage<'a> { ); } DataInstKind::QPtr(QPtrOp::BufferData) => { + assert_eq!(per_output_usage.len(), 1); generate_usage( self, data_inst_def.inputs[0], - output_usage + per_output_usage[0] + .take() .unwrap_or(Ok(QPtrUsage::Memory(QPtrMemUsage::UNUSED))) .and_then(|usage| { let usage = match usage { @@ -1051,10 +1128,11 @@ impl<'a> InferUsage<'a> { ); } &DataInstKind::QPtr(QPtrOp::Offset(offset)) => { + assert_eq!(per_output_usage.len(), 1); generate_usage( self, data_inst_def.inputs[0], - output_usage + per_output_usage[0].take() .unwrap_or(Ok(QPtrUsage::Memory(QPtrMemUsage::UNUSED))) .and_then(|usage| { let 
usage = match usage { @@ -1093,10 +1171,11 @@ impl<'a> InferUsage<'a> { ); } DataInstKind::QPtr(QPtrOp::DynOffset { stride, index_bounds }) => { + assert_eq!(per_output_usage.len(), 1); generate_usage( self, data_inst_def.inputs[0], - output_usage + per_output_usage[0].take() .unwrap_or(Ok(QPtrUsage::Memory(QPtrMemUsage::UNUSED))) .and_then(|usage| { let usage = match usage { @@ -1151,9 +1230,7 @@ impl<'a> InferUsage<'a> { op @ (QPtrOp::Load { offset } | QPtrOp::Store { offset }), ) => { let (op_name, access_type) = match op { - QPtrOp::Load { .. } => { - ("Load", data_inst_form_def.output_type.unwrap()) - } + QPtrOp::Load { .. } => ("Load", data_inst_form_def.output_types[0]), QPtrOp::Store { .. } => { ("Store", func_at_inst.at(data_inst_def.inputs[1]).type_of(&cx)) } @@ -1228,7 +1305,8 @@ impl<'a> InferUsage<'a> { ); } - DataInstKind::SpvInst(_) | DataInstKind::SpvExtInst { .. } => { + DataInstKind::SpvInst(_, lowering) + | DataInstKind::SpvExtInst { lowering, .. } => { let mut has_from_spv_ptr_output_attr = false; for attr in &cx[data_inst_def.attrs].attrs { match *attr { @@ -1253,7 +1331,8 @@ impl<'a> InferUsage<'a> { // since it would conflict with our // own `Block`-annotated wrapper. shapes::Handle::Buffer(..) => { - return Err(AnalysisError(Diag::bug(["ToSpvPtrInput: whole Buffer ambiguous (handle vs buffer data)".into()]) + return Err(AnalysisError( + Diag::bug(["ToSpvPtrInput: whole Buffer ambiguous (handle vs buffer data)".into()]) )); } }; @@ -1266,7 +1345,8 @@ impl<'a> InferUsage<'a> { // a generated type that matches the // desired `pointee` type. TypeLayout::HandleArray(..) 
=> { - Err(AnalysisError(Diag::bug(["ToSpvPtrInput: whole handle array unrepresentable".into()]) + Err(AnalysisError( + Diag::bug(["ToSpvPtrInput: whole handle array unrepresentable".into()]) )) } TypeLayout::Concrete(concrete) => { @@ -1299,10 +1379,14 @@ impl<'a> InferUsage<'a> { } if has_from_spv_ptr_output_attr { + assert!(lowering.disaggregated_output.is_none()); + assert_eq!(per_output_usage.len(), 1); // FIXME(eddyb) merge with `FromSpvPtrOutput`'s `pointee`. - if let Some(usage) = output_usage { - usage_or_err_attrs_to_attach - .push((Value::DataInstOutput(data_inst), usage)); + if let Some(usage) = per_output_usage[0].take() { + usage_or_err_attrs_to_attach.push(( + Value::DataInstOutput { inst: data_inst, output_idx: 0 }, + usage, + )); } } } diff --git a/src/qptr/layout.rs b/src/qptr/layout.rs index 0617ddb7..8d550428 100644 --- a/src/qptr/layout.rs +++ b/src/qptr/layout.rs @@ -356,7 +356,7 @@ impl<'a> LayoutCache<'a> { ["`layout_of(qptr)` (already lowered?)".into()], ))); } - TypeKind::SpvInst { spv_inst, type_and_const_inputs } => { + TypeKind::SpvInst { spv_inst, type_and_const_inputs, .. } => { (spv_inst, type_and_const_inputs) } TypeKind::SpvStringLiteralForExtInst => { diff --git a/src/qptr/lift.rs b/src/qptr/lift.rs index 69fc221c..a793cdd5 100644 --- a/src/qptr/lift.rs +++ b/src/qptr/lift.rs @@ -160,7 +160,7 @@ impl<'a> LiftToSpvPtrs<'a> { // FIXME(eddyb) deduplicate with `qptr::lower`. fn as_spv_ptr_type(&self, ty: Type) -> Option<(AddrSpace, Type)> { match &self.cx[ty].kind { - TypeKind::SpvInst { spv_inst, type_and_const_inputs } + TypeKind::SpvInst { spv_inst, type_and_const_inputs, .. } if spv_inst.opcode == self.wk.OpTypePointer => { let sc = match spv_inst.imms[..] 
{ @@ -184,13 +184,16 @@ impl<'a> LiftToSpvPtrs<'a> { AddrSpace::Handles => unreachable!(), AddrSpace::SpvStorageClass(storage_class) => storage_class, }; - self.cx.intern(TypeKind::SpvInst { - spv_inst: spv::Inst { + self.cx.intern( + spv::Inst { opcode: wk.OpTypePointer, imms: [spv::Imm::Short(wk.StorageClass, storage_class)].into_iter().collect(), - }, - type_and_const_inputs: [TypeOrConst::Type(pointee_type)].into_iter().collect(), - }) + } + .into_canonical_type_with( + &self.cx, + [TypeOrConst::Type(pointee_type)].into_iter().collect(), + ), + ) } fn pointee_type_for_usage(&self, usage: &QPtrUsage) -> Result { @@ -283,9 +286,9 @@ impl<'a> LiftToSpvPtrs<'a> { Ok(self.cx.intern(TypeDef { attrs: stride_attrs.unwrap_or_default(), - kind: TypeKind::SpvInst { - spv_inst: spv_opcode.into(), - type_and_const_inputs: [ + kind: spv::Inst::from(spv_opcode).into_canonical_type_with( + &self.cx, + [ Some(TypeOrConst::Type(element_type)), fixed_len.map(|len| { TypeOrConst::Const(self.cx.intern(scalar::Const::from_u32(len))) @@ -294,7 +297,7 @@ impl<'a> LiftToSpvPtrs<'a> { .into_iter() .flatten() .collect(), - }, + ), })) } @@ -326,7 +329,8 @@ impl<'a> LiftToSpvPtrs<'a> { attrs.attrs.extend(extra_attrs); Ok(self.cx.intern(TypeDef { attrs: self.cx.intern(attrs), - kind: TypeKind::SpvInst { spv_inst: wk.OpTypeStruct.into(), type_and_const_inputs }, + kind: spv::Inst::from(wk.OpTypeStruct) + .into_canonical_type_with(&self.cx, type_and_const_inputs), })) } @@ -390,7 +394,7 @@ impl LiftToSpvPtrInstsInFunc<'_> { // FIXME(eddyb) maybe all this data should be packaged up together in a // type with fields like those of `DeferredPtrNoop` (or even more). 
let type_of_val_as_spv_ptr_with_layout = |v: Value| { - if let Value::DataInstOutput(v_data_inst) = v { + if let Value::DataInstOutput { inst: v_data_inst, output_idx: 0 } = v { if let Some(ptr_noop) = self.deferred_ptr_noops.get(&v_data_inst) { return Ok(( ptr_noop.output_pointer_addr_space, @@ -428,18 +432,20 @@ impl LiftToSpvPtrInstsInFunc<'_> { DataInstDef { attrs: self.lifter.strip_qptr_usage_attr(data_inst_def.attrs), form: cx.intern(DataInstFormDef { - kind: DataInstKind::SpvInst(spv::Inst { - opcode: wk.OpVariable, - imms: [spv::Imm::Short(wk.StorageClass, wk.Function)] - .into_iter() - .collect(), - }), - output_type: Some( - self.lifter.spv_ptr_type( - AddrSpace::SpvStorageClass(wk.Function), - pointee_type, - ), + kind: DataInstKind::SpvInst( + spv::Inst { + opcode: wk.OpVariable, + imms: [spv::Imm::Short(wk.StorageClass, wk.Function)] + .into_iter() + .collect(), + }, + spv::InstLowering::default(), ), + output_types: [self + .lifter + .spv_ptr_type(AddrSpace::SpvStorageClass(wk.Function), pointee_type)] + .into_iter() + .collect(), }), inputs: data_inst_def.inputs.clone(), } @@ -466,8 +472,13 @@ impl LiftToSpvPtrInstsInFunc<'_> { DataInstDef { attrs: data_inst_def.attrs, form: cx.intern(DataInstFormDef { - kind: DataInstKind::SpvInst(wk.OpAccessChain.into()), - output_type: Some(self.lifter.spv_ptr_type(addr_space, handle_type)), + kind: DataInstKind::SpvInst( + wk.OpAccessChain.into(), + spv::InstLowering::default(), + ), + output_types: [self.lifter.spv_ptr_type(addr_space, handle_type)] + .into_iter() + .collect(), }), inputs: data_inst_def.inputs.clone(), } @@ -497,7 +508,7 @@ impl LiftToSpvPtrInstsInFunc<'_> { // maybe don't even replace the `QPtrOp::Buffer` instruction? 
form: cx.intern(DataInstFormDef { kind: QPtrOp::BufferData.into(), - output_type: Some(type_of_val(buf_ptr)), + output_types: [type_of_val(buf_ptr)].into_iter().collect(), }), ..data_inst_def.clone() } @@ -536,13 +547,16 @@ impl LiftToSpvPtrInstsInFunc<'_> { DataInstDef { form: cx.intern(DataInstFormDef { - kind: DataInstKind::SpvInst(spv::Inst { - opcode: wk.OpArrayLength, - imms: [spv::Imm::Short(wk.LiteralInteger, field_idx)] - .into_iter() - .collect(), - }), - output_type: data_inst_form_def.output_type, + kind: DataInstKind::SpvInst( + spv::Inst { + opcode: wk.OpArrayLength, + imms: [spv::Imm::Short(wk.LiteralInteger, field_idx)] + .into_iter() + .collect(), + }, + spv::InstLowering::default(), + ), + output_types: data_inst_form_def.output_types.clone(), }), ..data_inst_def.clone() } @@ -574,7 +588,7 @@ impl LiftToSpvPtrInstsInFunc<'_> { // maybe don't even replace the `QPtrOp::Offset` instruction? form: cx.intern(DataInstFormDef { kind: QPtrOp::Offset(0).into(), - output_type: Some(type_of_val(base_ptr)), + output_types: [type_of_val(base_ptr)].into_iter().collect(), }), ..data_inst_def.clone() } @@ -656,17 +670,20 @@ impl LiftToSpvPtrInstsInFunc<'_> { DataInstDef { attrs: data_inst_def.attrs, form: cx.intern(DataInstFormDef { - kind: DataInstKind::SpvInst(wk.OpAccessChain.into()), - output_type: Some( - self.lifter.spv_ptr_type(addr_space, layout.original_type), + kind: DataInstKind::SpvInst( + wk.OpAccessChain.into(), + spv::InstLowering::default(), ), + output_types: [self.lifter.spv_ptr_type(addr_space, layout.original_type)] + .into_iter() + .collect(), }), inputs: access_chain_inputs, } } DataInstKind::QPtr(op @ (QPtrOp::Load { offset } | QPtrOp::Store { offset })) => { let (spv_opcode, access_type) = match op { - QPtrOp::Load { .. } => (wk.OpLoad, data_inst_form_def.output_type.unwrap()), + QPtrOp::Load { .. } => (wk.OpLoad, data_inst_form_def.output_types[0]), QPtrOp::Store { .. 
} => (wk.OpStore, type_of_val(data_inst_def.inputs[1])), _ => unreachable!(), }; @@ -689,8 +706,11 @@ impl LiftToSpvPtrInstsInFunc<'_> { let mut new_data_inst_def = DataInstDef { form: cx.intern(DataInstFormDef { - kind: DataInstKind::SpvInst(spv_opcode.into()), - output_type: data_inst_form_def.output_type, + kind: DataInstKind::SpvInst( + spv_opcode.into(), + spv::InstLowering::default(), + ), + output_types: data_inst_form_def.output_types.clone(), }), ..data_inst_def.clone() }; @@ -722,13 +742,13 @@ impl LiftToSpvPtrInstsInFunc<'_> { } new_data_inst_def.inputs[input_idx] = - Value::DataInstOutput(access_chain_data_inst); + Value::DataInstOutput { inst: access_chain_data_inst, output_idx: 0 }; } new_data_inst_def } - DataInstKind::SpvInst(_) | DataInstKind::SpvExtInst { .. } => { + DataInstKind::SpvInst(_, lowering) | DataInstKind::SpvExtInst { lowering, .. } => { let mut to_spv_ptr_input_adjustments = vec![]; let mut from_spv_ptr_output = None; for attr in &cx[data_inst_def.attrs].attrs { @@ -798,13 +818,19 @@ impl LiftToSpvPtrInstsInFunc<'_> { } new_data_inst_def.inputs[input_idx] = - Value::DataInstOutput(access_chain_data_inst); + Value::DataInstOutput { inst: access_chain_data_inst, output_idx: 0 }; } if let Some((addr_space, pointee_type)) = from_spv_ptr_output { + assert!(lowering.disaggregated_output.is_none()); + + let data_inst_form_def = &cx[new_data_inst_def.form]; + assert_eq!(data_inst_form_def.output_types.len(), 1); new_data_inst_def.form = cx.intern(DataInstFormDef { - output_type: Some(self.lifter.spv_ptr_type(addr_space, pointee_type)), - ..cx[new_data_inst_def.form].clone() + output_types: [self.lifter.spv_ptr_type(addr_space, pointee_type)] + .into_iter() + .collect(), + ..data_inst_form_def.clone() }); } @@ -834,8 +860,13 @@ impl LiftToSpvPtrInstsInFunc<'_> { Some(DataInstDef { attrs: Default::default(), form: self.lifter.cx.intern(DataInstFormDef { - kind: DataInstKind::SpvInst(wk.OpAccessChain.into()), - output_type: 
Some(self.lifter.spv_ptr_type(addr_space, final_pointee_type)), + kind: DataInstKind::SpvInst( + wk.OpAccessChain.into(), + spv::InstLowering::default(), + ), + output_types: [self.lifter.spv_ptr_type(addr_space, final_pointee_type)] + .into_iter() + .collect(), }), inputs: access_chain_inputs, }) @@ -1007,8 +1038,8 @@ impl LiftToSpvPtrInstsInFunc<'_> { for v in values { // FIXME(eddyb) the loop could theoretically be avoided, but that'd // make tracking use counts harder. - while let Value::DataInstOutput(data_inst) = *v { - match self.deferred_ptr_noops.get(&data_inst) { + while let Value::DataInstOutput { inst, output_idx: 0 } = *v { + match self.deferred_ptr_noops.get(&inst) { Some(ptr_noop) => { *v = ptr_noop.output_pointer; } @@ -1022,8 +1053,8 @@ impl LiftToSpvPtrInstsInFunc<'_> { // encoded as `Option` for (dense) map entry reasons. fn add_value_uses(&mut self, values: &[Value]) { for &v in values { - if let Value::DataInstOutput(data_inst) = v { - let count = self.data_inst_use_counts.entry(data_inst); + if let Value::DataInstOutput { inst, .. } = v { + let count = self.data_inst_use_counts.entry(inst); *count = Some( NonZeroU32::new(count.map_or(0, |c| c.get()).checked_add(1).unwrap()).unwrap(), ); @@ -1032,8 +1063,8 @@ impl LiftToSpvPtrInstsInFunc<'_> { } fn remove_value_uses(&mut self, values: &[Value]) { for &v in values { - if let Value::DataInstOutput(data_inst) = v { - let count = self.data_inst_use_counts.entry(data_inst); + if let Value::DataInstOutput { inst, .. 
} = v { + let count = self.data_inst_use_counts.entry(inst); *count = NonZeroU32::new(count.unwrap().get() - 1); } } @@ -1088,11 +1119,14 @@ impl Transformer for LiftToSpvPtrInstsInFunc<'_> { if let DataInstKind::QPtr(_) = data_inst_form_def.kind { lifted = Err(LiftError(Diag::bug(["unimplemented qptr instruction".into()]))); - } else if let Some(ty) = data_inst_form_def.output_type { - if matches!(self.lifter.cx[ty].kind, TypeKind::QPtr) { - lifted = Err(LiftError(Diag::bug([ - "unimplemented qptr-producing instruction".into(), - ]))); + } else { + for &ty in &data_inst_form_def.output_types { + if matches!(self.lifter.cx[ty].kind, TypeKind::QPtr) { + lifted = Err(LiftError(Diag::bug([ + "unimplemented qptr-producing instruction".into(), + ]))); + break; + } } } } diff --git a/src/qptr/lower.rs b/src/qptr/lower.rs index 2fe70c01..23c39559 100644 --- a/src/qptr/lower.rs +++ b/src/qptr/lower.rs @@ -163,7 +163,7 @@ impl<'a> LowerFromSpvPtrs<'a> { // (!!! may cause bad interactions with storage class inference `Generic` abuse) fn as_spv_ptr_type(&self, ty: Type) -> Option<(AddrSpace, Type)> { match &self.cx[ty].kind { - TypeKind::SpvInst { spv_inst, type_and_const_inputs } + TypeKind::SpvInst { spv_inst, type_and_const_inputs, .. } if spv_inst.opcode == self.wk.OpTypePointer => { let sc = match spv_inst.imms[..] { @@ -423,17 +423,22 @@ impl LowerFromSpvPtrInstsInFunc<'_> { let func = func_at_data_inst_frozen.at(()); let attrs = data_inst_def.attrs; - let DataInstFormDef { ref kind, output_type } = cx[data_inst_def.form]; + let DataInstFormDef { kind, output_types } = &cx[data_inst_def.form]; - let spv_inst = match kind { - DataInstKind::SpvInst(spv_inst) => spv_inst, + let (spv_inst, spv_inst_lowering) = match kind { + DataInstKind::SpvInst(spv_inst, lowering) => (spv_inst, lowering), _ => return Ok(Transformed::Unchanged), }; + // HACK(eddyb) this is for easy bailing/asserting. 
+ let disaggregated_output_or_inputs_during_lowering = + spv_inst_lowering.disaggregated_output.is_some() + || !spv_inst_lowering.disaggregated_inputs.is_empty(); + // Flatten `QPtrOp::Offset`s behind `ptr` into a base pointer and offset. let flatten_offsets = |mut ptr| { let mut offset = 0; - while let Value::DataInstOutput(ptr_inst) = ptr { + while let Value::DataInstOutput { inst: ptr_inst, output_idx: 0 } = ptr { let ptr_inst_def = func.at(ptr_inst).def(); match cx[ptr_inst_def.form].kind { DataInstKind::QPtr(QPtrOp::Offset(ptr_offset)) => { @@ -452,9 +457,12 @@ impl LowerFromSpvPtrInstsInFunc<'_> { }; let replacement_kind_and_inputs = if spv_inst.opcode == wk.OpVariable { + assert!(!disaggregated_output_or_inputs_during_lowering); + assert_eq!(output_types.len(), 1); assert!(data_inst_def.inputs.len() <= 1); + let (_, var_data_type) = - self.lowerer.as_spv_ptr_type(output_type.unwrap()).ok_or_else(|| { + self.lowerer.as_spv_ptr_type(output_types[0]).ok_or_else(|| { LowerError(Diag::bug(["output type not an `OpTypePointer`".into()])) })?; match self.lowerer.layout_of(var_data_type)? { @@ -467,6 +475,10 @@ impl LowerFromSpvPtrInstsInFunc<'_> { _ => return Ok(Transformed::Unchanged), } } else if spv_inst.opcode == wk.OpLoad { + // FIXME(eddyb) expand into per-leaf accesses. + if disaggregated_output_or_inputs_during_lowering { + return Ok(Transformed::Unchanged); + } // FIXME(eddyb) support memory operands somehow. if !spv_inst.imms.is_empty() { return Ok(Transformed::Unchanged); @@ -479,6 +491,10 @@ impl LowerFromSpvPtrInstsInFunc<'_> { (QPtrOp::Load { offset }.into(), [ptr].into_iter().collect()) } else if spv_inst.opcode == wk.OpStore { + // FIXME(eddyb) expand into per-leaf accesses. + if disaggregated_output_or_inputs_during_lowering { + return Ok(Transformed::Unchanged); + } // FIXME(eddyb) support memory operands somehow. 
if !spv_inst.imms.is_empty() { return Ok(Transformed::Unchanged); @@ -492,6 +508,7 @@ impl LowerFromSpvPtrInstsInFunc<'_> { (QPtrOp::Store { offset }.into(), [ptr, value].into_iter().collect()) } else if spv_inst.opcode == wk.OpArrayLength { + assert!(!disaggregated_output_or_inputs_during_lowering); let field_idx = match spv_inst.imms[..] { [spv::Imm::Short(_, field_idx)] => field_idx, _ => unreachable!(), @@ -561,6 +578,8 @@ impl LowerFromSpvPtrInstsInFunc<'_> { ] .contains(&spv_inst.opcode) { + assert!(!disaggregated_output_or_inputs_during_lowering); + // FIXME(eddyb) avoid erasing the "inbounds" qualifier. let base_ptr = data_inst_def.inputs[0]; let (_, base_pointee_type) = @@ -572,11 +591,12 @@ impl LowerFromSpvPtrInstsInFunc<'_> { // a `OpTypeRuntimeArray`, with the original type as the element type. let access_chain_base_layout = if [wk.OpPtrAccessChain, wk.OpInBoundsPtrAccessChain].contains(&spv_inst.opcode) { - self.lowerer.layout_of(cx.intern(TypeKind::SpvInst { - spv_inst: wk.OpTypeRuntimeArray.into(), - type_and_const_inputs: + self.lowerer.layout_of(cx.intern( + spv::Inst::from(wk.OpTypeRuntimeArray).into_canonical_type_with( + cx, [TypeOrConst::Type(base_pointee_type)].into_iter().collect(), - }))? + ), + ))? } else { self.lowerer.layout_of(base_pointee_type)? 
}; @@ -609,7 +629,7 @@ impl LowerFromSpvPtrInstsInFunc<'_> { attrs: Default::default(), form: cx.intern(DataInstFormDef { kind, - output_type: Some(self.lowerer.qptr_type()), + output_types: [self.lowerer.qptr_type()].into_iter().collect(), }), inputs, } @@ -635,14 +655,18 @@ impl LowerFromSpvPtrInstsInFunc<'_> { _ => unreachable!(), } - ptr = Value::DataInstOutput(step_data_inst); + ptr = Value::DataInstOutput { inst: step_data_inst, output_idx: 0 }; } final_step.into_data_inst_kind_and_inputs(ptr) } else if spv_inst.opcode == wk.OpBitcast { + assert!(!disaggregated_output_or_inputs_during_lowering); + assert_eq!(output_types.len(), 1); + assert_eq!(data_inst_def.inputs.len(), 1); + let input = data_inst_def.inputs[0]; // Pointer-to-pointer casts are noops on `qptr`. if self.lowerer.as_spv_ptr_type(func.at(input).type_of(cx)).is_some() - && self.lowerer.as_spv_ptr_type(output_type.unwrap()).is_some() + && self.lowerer.as_spv_ptr_type(output_types[0]).is_some() { // HACK(eddyb) this will end added to `noop_offsets_to_base_ptr`, // which should replace all uses of this bitcast with its input. @@ -660,7 +684,7 @@ impl LowerFromSpvPtrInstsInFunc<'_> { // FIXME(eddyb) because this is now interned, it might be better to // temporarily track the old output types in a map, and not actually // intern the non-`qptr`-output `qptr.*` instructions. - form: cx.intern(DataInstFormDef { kind: new_kind, output_type }), + form: cx.intern(DataInstFormDef { kind: new_kind, output_types: output_types.clone() }), inputs: new_inputs, })) } @@ -679,19 +703,25 @@ impl LowerFromSpvPtrInstsInFunc<'_> { // FIXME(eddyb) is this a good convention? let func = func_at_data_inst_frozen.at(()); - match data_inst_form_def.kind { + let spv_inst_lowering = match &data_inst_form_def.kind { // Known semantics, no need to preserve SPIR-V pointer information. 
DataInstKind::Scalar(_) | DataInstKind::Vector(_) | DataInstKind::FuncCall(_) | DataInstKind::QPtr(_) => return, - DataInstKind::SpvInst(_) | DataInstKind::SpvExtInst { .. } => {} - } + DataInstKind::SpvInst(_, lowering) | DataInstKind::SpvExtInst { lowering, .. } => { + lowering + } + }; let mut old_and_new_attrs = None; let get_old_attrs = || AttrSetDef { attrs: cx[data_inst_def.attrs].attrs.clone() }; + if let Some(LowerError(e)) = extra_error { + old_and_new_attrs.get_or_insert_with(get_old_attrs).push_diag(e); + } + for (input_idx, &v) in data_inst_def.inputs.iter().enumerate() { if let Some((_, pointee)) = self.lowerer.as_spv_ptr_type(func.at(v).type_of(cx)) { old_and_new_attrs.get_or_insert_with(get_old_attrs).attrs.insert( @@ -703,8 +733,18 @@ impl LowerFromSpvPtrInstsInFunc<'_> { ); } } - if let Some(output_type) = data_inst_form_def.output_type { - if let Some((addr_space, pointee)) = self.lowerer.as_spv_ptr_type(output_type) { + for (output_idx, &ty) in data_inst_form_def.output_types.iter().enumerate() { + if let Some((addr_space, pointee)) = self.lowerer.as_spv_ptr_type(ty) { + // FIXME(eddyb) make this impossible by lowering all instructions + // that may produce aggregates with pointer leaves. 
+ if output_idx != 0 || spv_inst_lowering.disaggregated_output.is_some() { + old_and_new_attrs.get_or_insert_with(get_old_attrs).push_diag(Diag::bug([ + format!("unsupported pointer as aggregate leaf (output #{output_idx})") + .into(), + ])); + continue; + } + old_and_new_attrs.get_or_insert_with(get_old_attrs).attrs.insert( QPtrAttr::FromSpvPtrOutput { addr_space: OrdAssertEq(addr_space), @@ -715,10 +755,6 @@ impl LowerFromSpvPtrInstsInFunc<'_> { } } - if let Some(LowerError(e)) = extra_error { - old_and_new_attrs.get_or_insert_with(get_old_attrs).push_diag(e); - } - if let Some(attrs) = old_and_new_attrs { func_at_data_inst.def().attrs = cx.intern(attrs); } @@ -728,8 +764,8 @@ impl LowerFromSpvPtrInstsInFunc<'_> { // encoded as `Option` for (dense) map entry reasons. fn add_value_uses(&mut self, values: &[Value]) { for &v in values { - if let Value::DataInstOutput(data_inst) = v { - let count = self.data_inst_use_counts.entry(data_inst); + if let Value::DataInstOutput { inst, .. } = v { + let count = self.data_inst_use_counts.entry(inst); *count = Some( NonZeroU32::new(count.map_or(0, |c| c.get()).checked_add(1).unwrap()).unwrap(), ); @@ -738,8 +774,8 @@ impl LowerFromSpvPtrInstsInFunc<'_> { } fn remove_value_uses(&mut self, values: &[Value]) { for &v in values { - if let Value::DataInstOutput(data_inst) = v { - let count = self.data_inst_use_counts.entry(data_inst); + if let Value::DataInstOutput { inst, .. 
} = v { + let count = self.data_inst_use_counts.entry(inst); *count = NonZeroU32::new(count.unwrap().get() - 1); } } @@ -753,7 +789,7 @@ impl Transformer for LowerFromSpvPtrInstsInFunc<'_> { let mut v = *v; let transformed = match v { - Value::DataInstOutput(inst) => self + Value::DataInstOutput { inst, output_idx: 0 } => self .noop_offsets_to_base_ptr .get(&inst) .copied() @@ -806,7 +842,11 @@ impl Transformer for LowerFromSpvPtrInstsInFunc<'_> { if let QPtrOp::Offset(0) = op { let mut base_ptr = new_def.inputs[0]; - if let Value::DataInstOutput(base_ptr_inst) = base_ptr { + if let Value::DataInstOutput { + inst: base_ptr_inst, + output_idx: 0, + } = base_ptr + { if let Some(&base_ptr_base_ptr) = self.noop_offsets_to_base_ptr.get(&base_ptr_inst) { diff --git a/src/qptr/mod.rs b/src/qptr/mod.rs index e31f82cc..65221c93 100644 --- a/src/qptr/mod.rs +++ b/src/qptr/mod.rs @@ -35,7 +35,7 @@ pub enum QPtrAttr { // FIXME(eddyb) reduce usage by modeling more of SPIR-V inside SPIR-T. ToSpvPtrInput { input_idx: u32, pointee: OrdAssertEq }, - /// When applied to a `DataInst` with a `QPtr`-typed output value, + /// When applied to a `DataInst` with a single `QPtr`-typed output value, /// this describes the original `OpTypePointer` produced by an unknown /// SPIR-V instruction (likely creating it, without deriving from an input). /// diff --git a/src/spv/canonical.rs b/src/spv/canonical.rs index ea679873..ecdd453a 100644 --- a/src/spv/canonical.rs +++ b/src/spv/canonical.rs @@ -73,6 +73,7 @@ def_mappable_ops! { } const { OpUndef, + OpConstantNull, OpConstantFalse, OpConstantTrue, OpConstant, @@ -264,7 +265,23 @@ impl spv::Inst { // FIXME(eddyb) automate bidirectional mappings more (although the need // for conditional, i.e. "partial", mappings, adds a lot of complexity). 
- pub(super) fn as_canonical_type( + pub fn into_canonical_type_with( + self, + cx: &Context, + type_and_const_inputs: SmallVec<[TypeOrConst; 2]>, + ) -> TypeKind { + let value_lowering = match spv::AggregateShape::compute(cx, &self, &type_and_const_inputs) { + Some(aggregate_shape) => spv::ValueLowering::Disaggregate(aggregate_shape), + None => spv::ValueLowering::Direct, + }; + if let Some(type_kind) = self.as_canonical_non_spv_type(cx, &type_and_const_inputs) { + assert!(value_lowering == spv::ValueLowering::Direct); + type_kind + } else { + TypeKind::SpvInst { spv_inst: self, type_and_const_inputs, value_lowering } + } + } + fn as_canonical_non_spv_type( &self, cx: &Context, type_and_const_inputs: &[TypeOrConst], @@ -357,6 +374,12 @@ impl spv::Inst { mo.OpUndef == self.opcode } + // HACK(eddyb) this only exists as a helper for `spv::lower`. + pub(super) fn lower_const_by_distributing_to_aggregate_leaves(&self) -> bool { + let mo = MappableOps::get(); + [mo.OpUndef, mo.OpConstantNull].contains(&self.opcode) + } + // FIXME(eddyb) automate bidirectional mappings more (although the need // for conditional, i.e. "partial", mappings, adds a lot of complexity). 
pub(super) fn as_canonical_const( @@ -378,6 +401,27 @@ impl spv::Inst { (_, []) if opcode == mo.OpConstant => { Some(scalar::Const::try_decode_from_spv_imms(ty.as_scalar(cx)?, imms)?.into()) } + + ([], []) if opcode == mo.OpConstantNull => { + let null_scalar = |ty: scalar::Type| { + if ty.bit_width() > 128 { + return None; + } + Some(scalar::Const::from_bits(ty, 0)) + }; + match cx[ty].kind { + TypeKind::Scalar(ty) => Some(null_scalar(ty)?.into()), + TypeKind::Vector(ty) => { + let elem = null_scalar(ty.elem)?; + Some( + vector::Const::from_elems(ty, (0..ty.elem_count.get()).map(|_| elem)) + .into(), + ) + } + _ => None, + } + } + _ if opcode == wk.OpConstantComposite => { let ty = ty.as_vector(cx)?; let elems = (const_inputs.len() == usize::from(ty.elem_count.get()) @@ -385,6 +429,7 @@ impl spv::Inst { .then(|| const_inputs.iter().map(|ct| *ct.as_scalar(cx).unwrap()))?; Some(vector::Const::from_elems(ty, elems).into()) } + _ => None, } } diff --git a/src/spv/lift.rs b/src/spv/lift.rs index 3ceabd97..2d6e7dce 100644 --- a/src/spv/lift.rs +++ b/src/spv/lift.rs @@ -6,19 +6,31 @@ use crate::visit::{InnerVisit, Visitor}; use crate::{ cfg, scalar, AddrSpace, Attr, AttrSet, Const, ConstDef, ConstKind, Context, ControlNode, ControlNodeKind, ControlNodeOutputDecl, ControlRegion, ControlRegionInputDecl, DataInst, - DataInstDef, DataInstForm, DataInstFormDef, DataInstKind, DeclDef, EntityList, ExportKey, - Exportee, Func, FuncDecl, FuncDefBody, FuncParam, FxIndexMap, FxIndexSet, GlobalVar, - GlobalVarDefBody, Import, Module, ModuleDebugInfo, ModuleDialect, SelectionKind, Type, TypeDef, - TypeKind, TypeOrConst, Value, + DataInstDef, DataInstForm, DataInstFormDef, DataInstKind, DeclDef, EntityList, + EntityOrientedDenseMap, ExportKey, Exportee, Func, FuncDecl, FuncDefBody, FuncParam, + FxIndexMap, FxIndexSet, GlobalVar, GlobalVarDefBody, Import, Module, ModuleDebugInfo, + ModuleDialect, SelectionKind, Type, TypeDef, TypeKind, TypeOrConst, Value, }; +use itertools::Itertools; 
use rustc_hash::FxHashMap; use smallvec::SmallVec; use std::borrow::Cow; use std::collections::BTreeMap; use std::num::NonZeroUsize; +use std::ops::Range; use std::path::Path; use std::{io, iter, mem, slice}; +// HACK(eddyb) getting around the lack of a `Step` impl on `spv::Id` (`NonZeroU32`). +trait IdRangeExt { + fn iter(&self) -> iter::Map, fn(u32) -> spv::Id>; +} +impl IdRangeExt for Range { + fn iter(&self) -> iter::Map, fn(u32) -> spv::Id> { + (self.start.get()..self.end.get()).map(|i| spv::Id::new(i).unwrap()) + } +} + impl spv::Dialect { fn capability_insts(&self) -> impl Iterator + '_ { let wk = &spec::Spec::get().well_known; @@ -75,13 +87,21 @@ impl spv::ModuleDebugInfo { } } -struct IdAllocator<'a, AI: FnMut() -> spv::Id> { +/// ID allocation callback, kept as a closure (instead of having its state +/// be part of `Lifter`) to avoid misuse. +trait AllocIds: FnMut(usize) -> Range { + fn one(&mut self) -> spv::Id { + self(1).start + } +} + +impl Range> AllocIds for F {} + +struct Lifter<'a, AI: AllocIds> { cx: &'a Context, module: &'a Module, - /// ID allocation callback, kept as a closure (instead of having its state - /// be part of `IdAllocator`) to avoid misuse. - alloc_id: AI, + alloc_ids: AI, ids: ModuleIds<'a>, @@ -106,7 +126,6 @@ enum Global { Const(Const), } -// FIXME(eddyb) should this use ID ranges instead of `SmallVec<[spv::Id; 4]>`? // FIXME(eddyb) this is inconsistently named with `FuncBodyLifting`. 
struct FuncIds<'a> { spv_func_ret_type: Type, @@ -115,12 +134,12 @@ struct FuncIds<'a> { spv_func_type: Type, func_id: spv::Id, - param_ids: SmallVec<[spv::Id; 4]>, + param_ids: Range, body: Option>, } -impl spv::Id> Visitor<'_> for IdAllocator<'_, AI> { +impl Visitor<'_> for Lifter<'_, AI> { fn visit_attr_set_use(&mut self, attrs: AttrSet) { self.visit_attr_set_def(&self.cx[attrs]); } @@ -162,7 +181,7 @@ impl spv::Id> Visitor<'_> for IdAllocator<'_, AI> { } self.visit_type_def(ty_def); - self.ids.globals.insert(global, (self.alloc_id)()); + self.ids.globals.insert(global, self.alloc_ids.one()); } fn visit_const_use(&mut self, ct: Const) { let global = Global::Const(ct); @@ -188,7 +207,7 @@ impl spv::Id> Visitor<'_> for IdAllocator<'_, AI> { | ConstKind::PtrToGlobalVar(_) | ConstKind::SpvInst { .. } => { self.visit_const_def(ct_def); - self.ids.globals.insert(global, (self.alloc_id)()); + self.ids.globals.insert(global, self.alloc_ids.one()); } // HACK(eddyb) because this is an `OpString` and needs to go earlier @@ -206,7 +225,7 @@ impl spv::Id> Visitor<'_> for IdAllocator<'_, AI> { } ); - self.ids.debug_strings.entry(&self.cx[s]).or_insert_with(&mut self.alloc_id); + self.ids.debug_strings.entry(&self.cx[s]).or_insert_with(|| self.alloc_ids.one()); } } } @@ -229,14 +248,26 @@ impl spv::Id> Visitor<'_> for IdAllocator<'_, AI> { // Synthesize an `OpTypeFunction` type (that SPIR-T itself doesn't carry). let wk = &spec::Spec::get().well_known; - let spv_func_ret_type = func_decl.ret_type; - let spv_func_type = self.cx.intern(TypeKind::SpvInst { - spv_inst: wk.OpTypeFunction.into(), - type_and_const_inputs: iter::once(spv_func_ret_type) - .chain(func_decl.params.iter().map(|param| param.ty)) - .map(TypeOrConst::Type) - .collect(), - }); + let spv_func_ret_type = match &func_decl.ret_types[..] { + &[ty] => ty, + // Reaggregate multiple return types into an `OpTypeStruct`. 
+ ret_types => { + let opcode = if ret_types.is_empty() { wk.OpTypeVoid } else { wk.OpTypeStruct }; + self.cx.intern(spv::Inst::from(opcode).into_canonical_type_with( + self.cx, + ret_types.iter().copied().map(TypeOrConst::Type).collect(), + )) + } + }; + let spv_func_type = self.cx.intern( + spv::Inst::from(wk.OpTypeFunction).into_canonical_type_with( + self.cx, + iter::once(spv_func_ret_type) + .chain(func_decl.params.iter().map(|param| param.ty)) + .map(TypeOrConst::Type) + .collect(), + ), + ); self.visit_type_use(spv_func_type); // NOTE(eddyb) inserting first produces a different function ordering @@ -247,15 +278,16 @@ impl spv::Id> Visitor<'_> for IdAllocator<'_, AI> { FuncIds { spv_func_ret_type, spv_func_type, - func_id: (self.alloc_id)(), - param_ids: func_decl.params.iter().map(|_| (self.alloc_id)()).collect(), + func_id: self.alloc_ids.one(), + param_ids: (self.alloc_ids)(func_decl.params.len()), body: None, }, ); self.visit_func_decl(func_decl); - // Handle the body last, to minimize recursion hazards (see comment above). + // Handle the body last, to minimize recursion hazards (see comment above), + // and to allow `FuncBodyLifting` to look up its dependencies in `self.ids`. match &func_decl.def { DeclDef::Imported(_) => {} DeclDef::Present(func_def_body) => { @@ -269,7 +301,7 @@ impl spv::Id> Visitor<'_> for IdAllocator<'_, AI> { for sources in debug_info.source_languages.values() { // The file operand of `OpSource` has to point to an `OpString`. 
for &s in sources.file_contents.keys() { - self.ids.debug_strings.entry(&self.cx[s]).or_insert_with(&mut self.alloc_id); + self.ids.debug_strings.entry(&self.cx[s]).or_insert_with(|| self.alloc_ids.one()); } } } @@ -283,7 +315,7 @@ impl spv::Id> Visitor<'_> for IdAllocator<'_, AI> { self.ids .debug_strings .entry(&self.cx[file_path.0]) - .or_insert_with(&mut self.alloc_id); + .or_insert_with(|| self.alloc_ids.one()); } } attr.inner_visit_with(self); @@ -300,13 +332,13 @@ impl spv::Id> Visitor<'_> for IdAllocator<'_, AI> { DataInstKind::Scalar(_) | DataInstKind::Vector(_) | DataInstKind::FuncCall(_) - | DataInstKind::SpvInst(_) => {} + | DataInstKind::SpvInst(..) => {} DataInstKind::SpvExtInst { ext_set, .. } => { self.ids .ext_inst_imports .entry(&self.cx[ext_set]) - .or_insert_with(&mut self.alloc_id); + .or_insert_with(|| self.alloc_ids.one()); } } data_inst_form_def.inner_visit_with(self); @@ -315,10 +347,8 @@ impl spv::Id> Visitor<'_> for IdAllocator<'_, AI> { // FIXME(eddyb) this is inconsistently named with `FuncIds`. struct FuncBodyLifting<'a> { - // FIXME(eddyb) use `EntityOrientedDenseMap` here. - region_inputs_source: FxHashMap, - // FIXME(eddyb) use `EntityOrientedDenseMap` here. - data_inst_output_ids: FxHashMap, + region_inputs_source: EntityOrientedDenseMap, + data_insts: EntityOrientedDenseMap, label_ids: FxHashMap, blocks: FxIndexMap>, @@ -334,6 +364,40 @@ enum RegionInputsSource { LoopHeaderPhis(ControlNode), } +struct DataInstLifting { + result_id: Option, + + /// If the SPIR-V result type is "aggregate" (`OpTypeStruct`/`OpTypeArray`), + /// this describes how to extract its leaves, which is necessary as on the + /// SPIR-T side, [`Value::DataInstOutput`] can only refer to individual leaves. + disaggregate_result: Option, + + /// `reaggregate_inputs[i]` describes how to recreate the "aggregate" value + /// demanded by [`spv::InstLowering`]'s `disaggregated_inputs[i]`. 
+ reaggregate_inputs: SmallVec<[ReaggregateFromLeaves; 1]>, +} + +/// All the information necessary to decompose a SPIR-V "aggregate" value into +/// its leaves, with one `OpCompositeExtract` per leaf. +// +// FIXME(eddyb) it might be more efficient to only extract actually used leaves, +// or chain partial extracts following nesting structure - but this is simpler. +struct DisaggregateToLeaves { + op_composite_extract_result_ids: Range<spv::Id>, +} + +/// All the information necessary to recreate a SPIR-V "aggregate" value, with +/// one `OpCompositeInsert` per leaf (starting with an `OpUndef` of that type). +// +// FIXME(eddyb) it might be more efficient to use other strategies, such as +// `OpCompositeConstruct`, special-casing constants, reusing whole results +// of other `DataInstDef`s with an aggregate result, etc. - but this is simpler +// for now, and it reuses the "one instruction per leaf" used for extractions. +struct ReaggregateFromLeaves { + op_undef: Const, + op_composite_insert_result_ids: Range<spv::Id>, +} + /// Any of the possible points in structured or unstructured SPIR-T control-flow, /// that may require a separate SPIR-V basic block. #[derive(Copy, Clone, PartialEq, Eq, Hash)] @@ -377,6 +441,11 @@ struct Terminator<'a> { kind: Cow<'a, cfg::ControlInstKind>, + /// If this is a [`cfg::ControlInstKind::Return`] with `inputs.len() > 1`, + /// this ID is for the `OpCompositeConstruct` needed to produce the single + /// `OpTypeStruct` (`spv_func_ret_type`) value required by `OpReturnValue`. + reaggregated_return_value_id: Option<spv::Id>, + // FIXME(eddyb) use `Cow` or something, but ideally the "owned" case always // has at most one input, so allocating a whole `Vec` for that seems unwise. 
inputs: SmallVec<[Value; 2]>, @@ -543,12 +612,13 @@ impl<'a> FuncAt<'a, ControlNode> { impl<'a> FuncBodyLifting<'a> { fn from_func_def_body( - id_allocator: &mut IdAllocator<'_, impl FnMut() -> spv::Id>, + lifter: &mut Lifter<'_, impl AllocIds>, func_def_body: &'a FuncDefBody, ) -> Self { - let cx = id_allocator.cx; + let wk = &spec::Spec::get().well_known; + let cx = lifter.cx; - let mut region_inputs_source = FxHashMap::default(); + let mut region_inputs_source = EntityOrientedDenseMap::new(); region_inputs_source.insert(func_def_body.body, RegionInputsSource::FuncParams); // Create a SPIR-V block for every CFG point needing one. @@ -558,7 +628,7 @@ impl<'a> FuncBodyLifting<'a> { let phis = match point { CfgPoint::RegionEntry(region) => { - if region_inputs_source.contains_key(&region) { + if region_inputs_source.get(region).is_some() { // Region inputs handled by the parent of the region. SmallVec::new() } else { @@ -571,7 +641,7 @@ impl<'a> FuncBodyLifting<'a> { attrs, ty, - result_id: (id_allocator.alloc_id)(), + result_id: lifter.alloc_ids.one(), cases: FxIndexMap::default(), default_value: None, }) @@ -603,7 +673,7 @@ impl<'a> FuncBodyLifting<'a> { attrs, ty, - result_id: (id_allocator.alloc_id)(), + result_id: lifter.alloc_ids.one(), cases: FxIndexMap::default(), default_value: Some(initial_inputs[i]), }) @@ -621,7 +691,7 @@ impl<'a> FuncBodyLifting<'a> { attrs, ty, - result_id: (id_allocator.alloc_id)(), + result_id: lifter.alloc_ids.one(), cases: FxIndexMap::default(), default_value: None, }) @@ -652,6 +722,12 @@ impl<'a> FuncBodyLifting<'a> { Terminator { attrs: *attrs, kind: Cow::Borrowed(kind), + reaggregated_return_value_id: match kind { + cfg::ControlInstKind::Return if inputs.len() > 1 => { + Some(lifter.alloc_ids.one()) + } + _ => None, + }, // FIXME(eddyb) borrow these whenever possible. inputs: inputs.clone(), targets: targets @@ -669,10 +745,16 @@ impl<'a> FuncBodyLifting<'a> { } else { // Structured return out of the function body. 
assert!(region == func_def_body.body); + let inputs = func_def_body.at_body().def().outputs.clone(); Terminator { attrs: AttrSet::default(), kind: Cow::Owned(cfg::ControlInstKind::Return), - inputs: func_def_body.at_body().def().outputs.clone(), + reaggregated_return_value_id: if inputs.len() > 1 { + Some(lifter.alloc_ids.one()) + } else { + None + }, + inputs, targets: [].into_iter().collect(), target_phi_values: FxIndexMap::default(), merge: None, @@ -691,6 +773,7 @@ impl<'a> FuncBodyLifting<'a> { ControlNodeKind::Select { kind, scrutinee, cases } => Terminator { attrs: AttrSet::default(), kind: Cow::Owned(cfg::ControlInstKind::SelectBranch(kind.clone())), + reaggregated_return_value_id: None, inputs: [*scrutinee].into_iter().collect(), targets: cases .iter() @@ -704,6 +787,7 @@ impl<'a> FuncBodyLifting<'a> { Terminator { attrs: AttrSet::default(), kind: Cow::Owned(cfg::ControlInstKind::Branch), + reaggregated_return_value_id: None, inputs: [].into_iter().collect(), targets: [CfgPoint::RegionEntry(*body)].into_iter().collect(), target_phi_values: FxIndexMap::default(), @@ -742,6 +826,7 @@ impl<'a> FuncBodyLifting<'a> { ControlNodeKind::Select { .. 
} => Terminator { attrs: AttrSet::default(), kind: Cow::Owned(cfg::ControlInstKind::Branch), + reaggregated_return_value_id: None, inputs: [].into_iter().collect(), targets: [parent_exit].into_iter().collect(), target_phi_values: region_outputs @@ -768,6 +853,7 @@ impl<'a> FuncBodyLifting<'a> { Terminator { attrs: AttrSet::default(), kind: Cow::Owned(cfg::ControlInstKind::Branch), + reaggregated_return_value_id: None, inputs: [].into_iter().collect(), targets: [backedge].into_iter().collect(), target_phi_values, @@ -779,6 +865,7 @@ impl<'a> FuncBodyLifting<'a> { kind: Cow::Owned(cfg::ControlInstKind::SelectBranch( SelectionKind::BoolCond, )), + reaggregated_return_value_id: None, inputs: [repeat_condition].into_iter().collect(), targets: [backedge, parent_exit].into_iter().collect(), target_phi_values, @@ -794,6 +881,7 @@ impl<'a> FuncBodyLifting<'a> { (_, Some(succ_cursor)) => Terminator { attrs: AttrSet::default(), kind: Cow::Owned(cfg::ControlInstKind::Branch), + reaggregated_return_value_id: None, inputs: [].into_iter().collect(), targets: [succ_cursor.point].into_iter().collect(), target_phi_values: FxIndexMap::default(), @@ -865,11 +953,19 @@ impl<'a> FuncBodyLifting<'a> { let BlockLifting { terminator: original_terminator, .. 
} = &blocks[block_idx]; let is_trivial_branch = { - let Terminator { attrs, kind, inputs, targets, target_phi_values, merge } = - original_terminator; + let Terminator { + attrs, + kind, + reaggregated_return_value_id, + inputs, + targets, + target_phi_values, + merge, + } = original_terminator; *attrs == AttrSet::default() && matches!(**kind, cfg::ControlInstKind::Branch) + && reaggregated_return_value_id.is_none() && inputs.is_empty() && targets.len() == 1 && target_phi_values.is_empty() @@ -896,6 +992,7 @@ impl<'a> FuncBodyLifting<'a> { Terminator { attrs: Default::default(), kind: Cow::Owned(cfg::ControlInstKind::Unreachable), + reaggregated_return_value_id: None, inputs: Default::default(), targets: Default::default(), target_phi_values: Default::default(), @@ -950,27 +1047,104 @@ impl<'a> FuncBodyLifting<'a> { } } - let all_insts_with_output = blocks + let mut data_insts = EntityOrientedDenseMap::new(); + let all_func_at_data_insts = blocks .values() .flat_map(|block| block.insts.iter().copied()) - .flat_map(|insts| func_def_body.at(insts)) - .filter(|&func_at_inst| cx[func_at_inst.def().form].output_type.is_some()) - .map(|func_at_inst| func_at_inst.position); + .flat_map(|insts| func_def_body.at(insts)); + for func_at_inst in all_func_at_data_insts { + let data_inst_form_def = &cx[func_at_inst.def().form]; + + let mut new_spv_inst_lowering = spv::InstLowering::default(); + let spv_inst_lowering = match &data_inst_form_def.kind { + // Disallowed while visiting. + DataInstKind::QPtr(_) => unreachable!(), + + DataInstKind::Scalar(_) | DataInstKind::Vector(_) => { + // FIXME(eddyb) deduplicate creating this `OpTypeStruct`. 
+ if data_inst_form_def.output_types.len() > 1 { + let tuple_ty = cx.intern( + spv::Inst::from(wk.OpTypeStruct).into_canonical_type_with( + cx, + data_inst_form_def + .output_types + .iter() + .copied() + .map(TypeOrConst::Type) + .collect(), + ), + ); + lifter.visit_type_use(tuple_ty); + new_spv_inst_lowering.disaggregated_output = Some(tuple_ty); + } + &new_spv_inst_lowering + } + + DataInstKind::FuncCall(callee) => { + if data_inst_form_def.output_types.len() > 1 { + new_spv_inst_lowering.disaggregated_output = + Some(lifter.ids.funcs[callee].spv_func_ret_type); + } + &new_spv_inst_lowering + } + + DataInstKind::SpvInst(_, lowering) | DataInstKind::SpvExtInst { lowering, .. } => { + lowering + } + }; + + let reaggregate_inputs = spv_inst_lowering + .disaggregated_inputs + .iter() + .map(|&(_, ty)| { + let op_undef = cx.intern(ConstDef { + attrs: AttrSet::default(), + ty, + kind: ConstKind::Undef, + }); + lifter.visit_const_use(op_undef); + let op_composite_insert_result_ids = + (lifter.alloc_ids)(cx[ty].disaggregated_leaf_count()); + ReaggregateFromLeaves { op_undef, op_composite_insert_result_ids } + }) + .collect(); + + // `OpFunctionCall` always has a result (but may be `OpTypeVoid`-typed). 
+ let has_result = matches!(data_inst_form_def.kind, DataInstKind::FuncCall(_)) + || spv_inst_lowering.disaggregated_output.is_some() + || !data_inst_form_def.output_types.is_empty(); + let result_id = if has_result { Some(lifter.alloc_ids.one()) } else { None }; + + let disaggregate_result = + spv_inst_lowering.disaggregated_output.map(|ty| DisaggregateToLeaves { + op_composite_extract_result_ids: (lifter.alloc_ids)( + cx[ty].disaggregated_leaf_count(), + ), + }); + + data_insts.insert( + func_at_inst.position, + DataInstLifting { result_id, disaggregate_result, reaggregate_inputs }, + ); + } Self { region_inputs_source, - data_inst_output_ids: all_insts_with_output - .map(|inst| (inst, (id_allocator.alloc_id)())) - .collect(), - label_ids: blocks.keys().map(|&point| (point, (id_allocator.alloc_id)())).collect(), + data_insts, + + label_ids: blocks.keys().map(|&point| (point, lifter.alloc_ids.one())).collect(), blocks, } } } /// Maybe-decorated "lazy" SPIR-V instruction, allowing separately emitting -/// decorations from attributes, and the instruction itself, without eagerly -/// allocating all the instructions. +/// *both* decorations (from certain [`Attr`]s), *and* the instruction itself, +/// without eagerly allocating all the instructions. +/// +/// Note that SPIR-T disaggregating SPIR-V `OpTypeStruct`/`OpTypeArray`s values +/// may require additional [`spv::Inst`]s for each `LazyInst`, either for +/// reaggregating inputs, or taking apart aggregate outputs. #[derive(Copy, Clone)] enum LazyInst<'a, 'b> { Global(Global), @@ -991,9 +1165,10 @@ enum LazyInst<'a, 'b> { }, DataInst { parent_func_ids: &'b FuncIds<'a>, - result_id: Option, data_inst_def: &'a DataInstDef, + data_inst_lifting: &'b DataInstLifting, }, + // FIXME(eddyb) should merge instructions be generated by `Terminator`? 
Merge(Merge), Terminator { parent_func_ids: &'b FuncIds<'a>, @@ -1002,6 +1177,15 @@ enum LazyInst<'a, 'b> { OpFunctionEnd, } +/// [`Attr::SpvDebugLine`], extracted from [`AttrSet`], and used for emitting +/// `OpLine`/`OpNoLine` SPIR-V instructions. +#[derive(Copy, Clone, PartialEq, Eq)] +struct SpvDebugLine { + file_path_id: spv::Id, + line: u32, + col: u32, +} + impl LazyInst<'_, '_> { fn result_id_attrs_and_import( self, @@ -1049,8 +1233,8 @@ impl LazyInst<'_, '_> { Self::OpFunctionParameter { param_id, param } => (Some(param_id), param.attrs, None), Self::OpLabel { label_id } => (Some(label_id), AttrSet::default(), None), Self::OpPhi { parent_func_ids: _, phi } => (Some(phi.result_id), phi.attrs, None), - Self::DataInst { parent_func_ids: _, result_id, data_inst_def } => { - (result_id, data_inst_def.attrs, None) + Self::DataInst { parent_func_ids: _, data_inst_def, data_inst_lifting } => { + (data_inst_lifting.result_id, data_inst_def.attrs, None) } Self::Merge(_) => (None, AttrSet::default(), None), Self::Terminator { parent_func_ids: _, terminator } => (None, terminator.attrs, None), @@ -1058,11 +1242,16 @@ impl LazyInst<'_, '_> { } } - fn to_inst_and_attrs( + /// Expand this `LazyInst` to one or more (see disaggregation/reaggregation + /// note in [`LazyInst`]'s doc comment for when it can be more than one) + /// [`spv::Inst`]s (with their respective [`SpvDebugLine`]s, if applicable), + /// with `each_spv_inst_with_debug_line` being called for each one. 
+ fn for_each_spv_inst_with_debug_line( self, module: &Module, ids: &ModuleIds<'_>, - ) -> (spv::InstWithIds, AttrSet) { + mut each_spv_inst_with_debug_line: impl FnMut(spv::InstWithIds, Option<SpvDebugLine>), + ) { let wk = &spec::Spec::get().well_known; let cx = module.cx_ref(); @@ -1075,8 +1264,16 @@ impl LazyInst<'_, '_> { Value::ControlRegionInput { region, input_idx } => { let input_idx = usize::try_from(input_idx).unwrap(); let parent_func_body_lifting = parent_func_ids.body.as_ref().unwrap(); - match parent_func_body_lifting.region_inputs_source.get(&region) { - Some(RegionInputsSource::FuncParams) => parent_func_ids.param_ids[input_idx], + match parent_func_body_lifting.region_inputs_source.get(region) { + Some(RegionInputsSource::FuncParams) => { + let param_id = parent_func_ids + .param_ids + .start + .checked_add(input_idx.try_into().unwrap()) + .unwrap(); + assert!(parent_func_ids.param_ids.contains(&param_id)); + param_id + } Some(&RegionInputsSource::LoopHeaderPhis(loop_node)) => { parent_func_body_lifting.blocks[&CfgPoint::ControlNodeEntry(loop_node)].phis [input_idx] @@ -1095,14 +1292,42 @@ impl LazyInst<'_, '_> { .phis[usize::try_from(output_idx).unwrap()] .result_id } - Value::DataInstOutput(inst) => { - parent_func_ids.body.as_ref().unwrap().data_inst_output_ids[&inst] + Value::DataInstOutput { inst, output_idx } => { + let output_idx = usize::try_from(output_idx).unwrap(); + let data_inst_lifting = &parent_func_ids.body.as_ref().unwrap().data_insts[inst]; + if let Some(disaggregate_result) = &data_inst_lifting.disaggregate_result { + let result_id = disaggregate_result + .op_composite_extract_result_ids + .start + .checked_add(output_idx.try_into().unwrap()) + .unwrap(); + assert!( + disaggregate_result.op_composite_extract_result_ids.contains(&result_id) + ); + result_id + } else { + assert_eq!(output_idx, 0); + data_inst_lifting.result_id.unwrap() + } } }; let (result_id, attrs, _) = self.result_id_attrs_and_import(module, ids); - let inst = match self { - 
Self::Global(global) => match global { + + // FIXME(eddyb) make this less of a search and more of a + // lookup by splitting attrs into key and value parts. + let spv_debug_line = cx[attrs].attrs.iter().find_map(|attr| match *attr { + Attr::SpvDebugLine { file_path, line, col } => { + Some(SpvDebugLine { file_path_id: ids.debug_strings[&cx[file_path.0]], line, col }) + } + _ => None, + }); + + // HACK(eddyb) there is no need to allow `spv_debug_line` to vary per-inst. + let mut each_inst = |inst| each_spv_inst_with_debug_line(inst, spv_debug_line); + + match self { + Self::Global(global) => each_inst(match global { Global::Type(ty) => { let ty_def = &cx[ty]; match spv::Inst::from_canonical_type(cx, &ty_def.kind) @@ -1114,7 +1339,7 @@ impl LazyInst<'_, '_> { } Ok((spv_inst, type_and_const_inputs)) - | Err(TypeKind::SpvInst { spv_inst, type_and_const_inputs }) => { + | Err(TypeKind::SpvInst { spv_inst, type_and_const_inputs, .. }) => { spv::InstWithIds { without_ids: spv_inst.clone(), result_type_id: None, @@ -1205,7 +1430,7 @@ impl LazyInst<'_, '_> { Err(ConstKind::SpvStringLiteralForExtInst(_)) => unreachable!(), } } - }, + }), Self::OpFunction { func_decl: _, func_ids } => { // FIXME(eddyb) make this less of a search and more of a // lookup by splitting attrs into key and value parts. 
@@ -1222,7 +1447,7 @@ impl LazyInst<'_, '_> { }) .unwrap_or(0); - spv::InstWithIds { + each_inst(spv::InstWithIds { without_ids: spv::Inst { opcode: wk.OpFunction, imms: iter::once(spv::Imm::Short(wk.FunctionControl, func_ctrl)).collect(), @@ -1230,21 +1455,21 @@ impl LazyInst<'_, '_> { result_type_id: Some(ids.globals[&Global::Type(func_ids.spv_func_ret_type)]), result_id, ids: iter::once(ids.globals[&Global::Type(func_ids.spv_func_type)]).collect(), - } + }); } - Self::OpFunctionParameter { param_id: _, param } => spv::InstWithIds { + Self::OpFunctionParameter { param_id: _, param } => each_inst(spv::InstWithIds { without_ids: wk.OpFunctionParameter.into(), result_type_id: Some(ids.globals[&Global::Type(param.ty)]), result_id, ids: [].into_iter().collect(), - }, - Self::OpLabel { label_id: _ } => spv::InstWithIds { + }), + Self::OpLabel { label_id: _ } => each_inst(spv::InstWithIds { without_ids: wk.OpLabel.into(), result_type_id: None, result_id, ids: [].into_iter().collect(), - }, - Self::OpPhi { parent_func_ids, phi } => spv::InstWithIds { + }), + Self::OpPhi { parent_func_ids, phi } => each_inst(spv::InstWithIds { without_ids: wk.OpPhi.into(), result_type_id: Some(ids.globals[&Global::Type(phi.ty)]), result_id: Some(phi.result_id), @@ -1258,46 +1483,148 @@ impl LazyInst<'_, '_> { ] }) .collect(), - }, - Self::DataInst { parent_func_ids, result_id: _, data_inst_def } => { - let DataInstFormDef { kind, output_type } = &cx[data_inst_def.form]; - let (inst, extra_initial_id_operand) = - match spv::Inst::from_canonical_data_inst_kind(kind).ok_or(kind) { - Ok(spv_inst) => (spv_inst, None), - - Err(DataInstKind::Scalar(_) | DataInstKind::Vector(_)) => { - unreachable!("should've been handled as canonical") + }), + Self::DataInst { parent_func_ids, data_inst_def, data_inst_lifting } => { + let DataInstFormDef { kind, output_types } = &cx[data_inst_def.form]; + + let mut id_operands = SmallVec::new(); + + let mut new_spv_inst_lowering = spv::InstLowering::default(); + 
let mut override_result_type = None; + let (inst, spv_inst_lowering) = match spv::Inst::from_canonical_data_inst_kind(kind) + .ok_or(kind) + { + Ok(spv_inst) => { + // FIXME(eddyb) deduplicate creating this `OpTypeStruct`. + if output_types.len() > 1 { + new_spv_inst_lowering.disaggregated_output = Some(cx.intern( + spv::Inst::from(wk.OpTypeStruct).into_canonical_type_with( + cx, + output_types.iter().copied().map(TypeOrConst::Type).collect(), + ), + )); } + (spv_inst, &new_spv_inst_lowering) + } + + Err(DataInstKind::Scalar(_) | DataInstKind::Vector(_)) => { + unreachable!("should've been handled as canonical") + } - // Disallowed while visiting. - Err(DataInstKind::QPtr(_)) => unreachable!(), + // Disallowed while visiting. + Err(DataInstKind::QPtr(_)) => unreachable!(), - Err(&DataInstKind::FuncCall(callee)) => { - (wk.OpFunctionCall.into(), Some(ids.funcs[&callee].func_id)) + // `OpFunctionCall` always has a result (but may be `OpTypeVoid`-typed). + Err(DataInstKind::FuncCall(callee)) => { + let callee_ids = &ids.funcs[callee]; + override_result_type = Some(callee_ids.spv_func_ret_type); + if output_types.len() > 1 { + new_spv_inst_lowering.disaggregated_output = override_result_type; } - Err(DataInstKind::SpvInst(inst)) => (inst.clone(), None), - Err(&DataInstKind::SpvExtInst { ext_set, inst }) => ( + id_operands.push(callee_ids.func_id); + (wk.OpFunctionCall.into(), &new_spv_inst_lowering) + } + Err(DataInstKind::SpvInst(inst, lowering)) => (inst.clone(), lowering), + Err(DataInstKind::SpvExtInst { ext_set, inst, lowering }) => { + id_operands.push(ids.ext_inst_imports[&cx[*ext_set]]); + ( spv::Inst { opcode: wk.OpExtInst, - imms: iter::once(spv::Imm::Short(wk.LiteralExtInstInteger, inst)) + imms: [spv::Imm::Short(wk.LiteralExtInstInteger, *inst)] + .into_iter() .collect(), }, - Some(ids.ext_inst_imports[&cx[ext_set]]), - ), + lowering, + ) + } + }; + + let int_imm = |i| spv::Imm::Short(wk.LiteralInteger, i); + + // Emit any `OpCompositeInsert`s needed by the 
inputs, first, + // while gathering the `id_operands` for the instruction itself. + let mut reaggregate_inputs = data_inst_lifting.reaggregate_inputs.iter(); + for id_operand in spv_inst_lowering.reaggreate_inputs(&data_inst_def.inputs) { + let value_to_id = |v| value_to_id(parent_func_ids, v); + let id_operand = match id_operand { + spv::ReaggregatedIdOperand::Direct(v) => value_to_id(v), + spv::ReaggregatedIdOperand::Aggregate { ty, leaves } => { + let result_type_id = Some(ids.globals[&Global::Type(ty)]); + + let ReaggregateFromLeaves { op_undef, op_composite_insert_result_ids } = + reaggregate_inputs.next().unwrap(); + let mut aggregate_id = ids.globals[&Global::Const(*op_undef)]; + let leaf_paths = ty + .disaggregated_leaf_types(cx) + .map_with_parent_component_path(|_, leaf_path| { + leaf_path.iter().map(|&(_, i)| i).map(int_imm).collect() + }); + for ((leaf_path_imms, op_composite_insert_result_id), &leaf_value) in + leaf_paths + .zip_eq(op_composite_insert_result_ids.iter()) + .zip_eq(leaves) + { + each_inst(spv::InstWithIds { + without_ids: spv::Inst { + opcode: wk.OpCompositeInsert, + imms: leaf_path_imms, + }, + result_type_id, + result_id: Some(op_composite_insert_result_id), + ids: [value_to_id(leaf_value), aggregate_id] + .into_iter() + .collect(), + }); + aggregate_id = op_composite_insert_result_id; + } + aggregate_id + } }; - spv::InstWithIds { + id_operands.push(id_operand); + } + assert!(reaggregate_inputs.next().is_none()); + + let result_type = + override_result_type.or(spv_inst_lowering.disaggregated_output).or_else(|| { + assert!(output_types.len() <= 1); + output_types.first().copied() + }); + each_inst(spv::InstWithIds { without_ids: inst, - result_type_id: output_type.map(|ty| ids.globals[&Global::Type(ty)]), + result_type_id: result_type.map(|ty| ids.globals[&Global::Type(ty)]), result_id, - ids: extra_initial_id_operand - .into_iter() - .chain( - data_inst_def.inputs.iter().map(|&v| value_to_id(parent_func_ids, v)), - ) - .collect(), + 
ids: id_operands, + }); + + // Emit any `OpCompositeExtract`s needed for the result, last. + if let Some(DisaggregateToLeaves { op_composite_extract_result_ids }) = + &data_inst_lifting.disaggregate_result + { + let aggregate_id = result_id.unwrap(); + let leaf_types_and_paths = spv_inst_lowering + .disaggregated_output + .unwrap() + .disaggregated_leaf_types(cx) + .map_with_parent_component_path(|leaf_type, leaf_path| { + (leaf_type, leaf_path.iter().map(|&(_, i)| i).map(int_imm).collect()) + }); + for ((leaf_type, leaf_path_imms), op_composite_extract_result_id) in + leaf_types_and_paths.zip_eq(op_composite_extract_result_ids.iter()) + { + each_inst(spv::InstWithIds { + without_ids: spv::Inst { + opcode: wk.OpCompositeExtract, + imms: leaf_path_imms, + }, + result_type_id: Some(ids.globals[&Global::Type(leaf_type)]), + result_id: Some(op_composite_extract_result_id), + ids: [aggregate_id].into_iter().collect(), + }); + } } } - Self::Merge(Merge::Selection(merge_label_id)) => spv::InstWithIds { + // FIXME(eddyb) should merge instructions be generated by `Terminator`? 
+ Self::Merge(Merge::Selection(merge_label_id)) => each_inst(spv::InstWithIds { without_ids: spv::Inst { opcode: wk.OpSelectionMerge, imms: [spv::Imm::Short(wk.SelectionControl, 0)].into_iter().collect(), @@ -1305,11 +1632,11 @@ impl LazyInst<'_, '_> { result_type_id: None, result_id: None, ids: [merge_label_id].into_iter().collect(), - }, + }), Self::Merge(Merge::Loop { loop_merge: merge_label_id, loop_continue: continue_label_id, - }) => spv::InstWithIds { + }) => each_inst(spv::InstWithIds { without_ids: spv::Inst { opcode: wk.OpLoopMerge, imms: [spv::Imm::Short(wk.LoopControl, 0)].into_iter().collect(), @@ -1317,10 +1644,10 @@ impl LazyInst<'_, '_> { result_type_id: None, result_id: None, ids: [merge_label_id, continue_label_id].into_iter().collect(), - }, + }), Self::Terminator { parent_func_ids, terminator } => { let parent_func_body_lifting = parent_func_ids.body.as_ref().unwrap(); - let mut ids: SmallVec<[_; 4]> = terminator + let mut id_operands = terminator .inputs .iter() .map(|&v| value_to_id(parent_func_ids, v)) @@ -1332,6 +1659,23 @@ impl LazyInst<'_, '_> { ) .collect(); + if let Some(reaggregated_value_id) = terminator.reaggregated_return_value_id { + assert!( + matches!(*terminator.kind, cfg::ControlInstKind::Return) + && terminator.inputs.len() > 1 + ); + + each_inst(spv::InstWithIds { + without_ids: wk.OpCompositeConstruct.into(), + result_type_id: Some( + ids.globals[&Global::Type(parent_func_ids.spv_func_ret_type)], + ), + result_id: Some(reaggregated_value_id), + ids: id_operands, + }); + id_operands = [reaggregated_value_id].into_iter().collect(); + } + // FIXME(eddyb) move some of this to `spv::canonical`. let inst = match &*terminator.kind { cfg::ControlInstKind::Unreachable => wk.OpUnreachable.into(), @@ -1339,6 +1683,8 @@ impl LazyInst<'_, '_> { if terminator.inputs.is_empty() { wk.OpReturn.into() } else { + // Multiple return values get reaggregated above. 
+ assert_eq!(id_operands.len(), 1); wk.OpReturnValue.into() } } @@ -1353,8 +1699,8 @@ impl LazyInst<'_, '_> { } cfg::ControlInstKind::SelectBranch(SelectionKind::Switch { case_consts }) => { // HACK(eddyb) move the default case from last back to first. - let default_target = ids.pop().unwrap(); - ids.insert(1, default_target); + let default_target = id_operands.pop().unwrap(); + id_operands.insert(1, default_target); spv::Inst { opcode: wk.OpSwitch, @@ -1365,16 +1711,20 @@ impl LazyInst<'_, '_> { } } }; - spv::InstWithIds { without_ids: inst, result_type_id: None, result_id: None, ids } + each_inst(spv::InstWithIds { + without_ids: inst, + result_type_id: None, + result_id: None, + ids: id_operands, + }); } - Self::OpFunctionEnd => spv::InstWithIds { + Self::OpFunctionEnd => each_inst(spv::InstWithIds { without_ids: wk.OpFunctionEnd.into(), result_type_id: None, result_id: None, ids: [].into_iter().collect(), - }, - }; - (inst, attrs) + }), + } } } @@ -1417,35 +1767,38 @@ impl Module { // Collect uses scattered throughout the module, allocating IDs for them. let (ids, id_bound) = { let mut id_bound = NonZeroUsize::MIN; - let mut id_allocator = IdAllocator { + let mut lifter = Lifter { cx: &cx, module: self, - alloc_id: || { - let id = id_bound; - id_bound = - id_bound.checked_add(1).expect("overflowing `usize` should be impossible"); + alloc_ids: |count| { + let start = id_bound; + let end = + start.checked_add(count).expect("overflowing `usize` should be impossible"); + id_bound = end; // NOTE(eddyb) `MAX` is just a placeholder - the check for overflows // is done below, after all IDs that may be allocated, have been // (this is in order to not need this closure to return a `Result`). 
- id.try_into().unwrap_or(spv::Id::new(u32::MAX).unwrap()) + let from_usize = + |id| spv::Id::try_from(id).unwrap_or(spv::Id::new(u32::MAX).unwrap()); + from_usize(start)..from_usize(end) }, ids: ModuleIds::default(), data_inst_forms_seen: FxIndexSet::default(), global_vars_seen: FxIndexSet::default(), }; - id_allocator.visit_module(self); + lifter.visit_module(self); // See comment on `global_var_to_id_giving_global` for why this is here. - for &gv in &id_allocator.global_vars_seen { - id_allocator + for &gv in &lifter.global_vars_seen { + lifter .ids .globals .entry(global_var_to_id_giving_global(gv)) - .or_insert_with(&mut id_allocator.alloc_id); + .or_insert_with(|| lifter.alloc_ids.one()); } - let ids = id_allocator.ids; + let ids = lifter.ids; let id_bound = spv::Id::try_from(id_bound).ok().ok_or_else(|| { io::Error::new( @@ -1472,10 +1825,11 @@ impl Module { _ => unreachable!(), }; - let param_insts = - func_ids.param_ids.iter().zip(&func_decl.params).map(|(¶m_id, param)| { - LazyInst::OpFunctionParameter { param_id, param } - }); + let param_insts = func_ids + .param_ids + .iter() + .zip_eq(&func_decl.params) + .map(|(param_id, param)| LazyInst::OpFunctionParameter { param_id, param }); let body_insts = body_with_lifting.map(|(func_def_body, func_body_lifting)| { func_body_lifting.blocks.iter().flat_map(move |(point, block)| { let BlockLifting { phis, insts, terminator } = block; @@ -1496,11 +1850,9 @@ impl Module { let data_inst_def = func_at_inst.def(); LazyInst::DataInst { parent_func_ids: func_ids, - result_id: cx[data_inst_def.form].output_type.map(|_| { - func_body_lifting.data_inst_output_ids - [&func_at_inst.position] - }), data_inst_def, + data_inst_lifting: &func_body_lifting.data_insts + [func_at_inst.position], } }), ) @@ -1771,55 +2123,56 @@ impl Module { let mut current_debug_line = None; let mut current_block_id = None; // HACK(eddyb) for `current_debug_line` resets. 
for lazy_inst in global_and_func_insts { - let (inst, attrs) = lazy_inst.to_inst_and_attrs(self, ids); - - // Reset line debuginfo when crossing/leaving blocks. - let new_block_id = if inst.opcode == wk.OpLabel { - Some(inst.result_id.unwrap()) - } else if inst.opcode == wk.OpFunctionEnd { - None - } else { - current_block_id - }; - if current_block_id != new_block_id { - current_debug_line = None; - } - current_block_id = new_block_id; - - // Determine whether to emit `OpLine`/`OpNoLine` before `inst`, - // in order to end up with the expected line debuginfo. - // FIXME(eddyb) make this less of a search and more of a - // lookup by splitting attrs into key and value parts. - let new_debug_line = cx[attrs].attrs.iter().find_map(|attr| match *attr { - Attr::SpvDebugLine { file_path, line, col } => { - Some((ids.debug_strings[&cx[file_path.0]], line, col)) + let mut result: Result<(), _> = Ok(()); + lazy_inst.for_each_spv_inst_with_debug_line(self, ids, |inst, new_debug_line| { + if result.is_err() { + return; } - _ => None, - }); - if current_debug_line != new_debug_line { - let (opcode, imms, ids) = match new_debug_line { - Some((file_path_id, line, col)) => ( - wk.OpLine, - [ - spv::Imm::Short(wk.LiteralInteger, line), - spv::Imm::Short(wk.LiteralInteger, col), - ] - .into_iter() - .collect(), - iter::once(file_path_id).collect(), - ), - None => (wk.OpNoLine, [].into_iter().collect(), [].into_iter().collect()), + + // Reset line debuginfo when crossing/leaving blocks. 
+ let new_block_id = if inst.opcode == wk.OpLabel { + Some(inst.result_id.unwrap()) + } else if inst.opcode == wk.OpFunctionEnd { + None + } else { + current_block_id }; - emitter.push_inst(&spv::InstWithIds { - without_ids: spv::Inst { opcode, imms }, - result_type_id: None, - result_id: None, - ids, - })?; - } - current_debug_line = new_debug_line; + if current_block_id != new_block_id { + current_debug_line = None; + } + current_block_id = new_block_id; + + // Determine whether to emit `OpLine`/`OpNoLine` before `inst`, + // in order to end up with the expected line debuginfo. + if current_debug_line != new_debug_line { + let (opcode, imms, ids) = match new_debug_line { + Some(SpvDebugLine { file_path_id, line, col }) => ( + wk.OpLine, + [ + spv::Imm::Short(wk.LiteralInteger, line), + spv::Imm::Short(wk.LiteralInteger, col), + ] + .into_iter() + .collect(), + iter::once(file_path_id).collect(), + ), + None => (wk.OpNoLine, [].into_iter().collect(), [].into_iter().collect()), + }; + result = emitter.push_inst(&spv::InstWithIds { + without_ids: spv::Inst { opcode, imms }, + result_type_id: None, + result_id: None, + ids, + }); + if result.is_err() { + return; + } + } + current_debug_line = new_debug_line; - emitter.push_inst(&inst)?; + result = emitter.push_inst(&inst); + }); + result?; } Ok(emitter) diff --git a/src/spv/lower.rs b/src/spv/lower.rs index f8b75b33..6d2c8575 100644 --- a/src/spv/lower.rs +++ b/src/spv/lower.rs @@ -5,14 +5,17 @@ use crate::spv::{self, spec}; use crate::{ cfg, print, scalar, AddrSpace, Attr, AttrSet, Const, ConstDef, ConstKind, Context, ControlNodeDef, ControlNodeKind, ControlRegion, ControlRegionDef, ControlRegionInputDecl, - DataInstDef, DataInstFormDef, DataInstKind, DeclDef, Diag, EntityDefs, EntityList, ExportKey, - Exportee, Func, FuncDecl, FuncDefBody, FuncParam, FxIndexMap, GlobalVarDecl, GlobalVarDefBody, - Import, InternedStr, Module, SelectionKind, Type, TypeDef, TypeKind, TypeOrConst, Value, + DataInst, DataInstDef, 
DataInstFormDef, DataInstKind, DeclDef, Diag, EntityDefs, EntityList, + ExportKey, Exportee, Func, FuncDecl, FuncDefBody, FuncParam, FxIndexMap, GlobalVarDecl, + GlobalVarDefBody, Import, InternedStr, Module, SelectionKind, Type, TypeDef, TypeKind, + TypeOrConst, Value, }; +use itertools::Either; use rustc_hash::FxHashMap; use smallvec::SmallVec; use std::collections::{BTreeMap, BTreeSet}; use std::num::NonZeroU32; +use std::ops::Range; use std::path::Path; use std::rc::Rc; use std::{io, mem}; @@ -22,6 +25,18 @@ enum IdDef { Type(Type), Const(Const), + /// Like `Const`, but for SPIR-V "aggregate" (`OpTypeStruct`/`OpTypeArray`) + /// constants (e.g. `OpConstantComposite`s of those types, but also more + /// general constants like `OpUndef`/`OpConstantNull` etc.). + AggregateConst { + // FIXME(eddyb) remove `whole_const` by always using the `leaves`. + whole_const: Const, + + whole_type: Type, + + leaves: SmallVec<[Const; 2]>, + }, + Func(Func), SpvExtInstImport(InternedStr), @@ -33,8 +48,10 @@ impl IdDef { match *self { // FIXME(eddyb) print these with some kind of "maximum depth", // instead of just describing the kind of definition. + // FIXME(eddyb) replace these with the `Diag` embedding system. IdDef::Type(_) => "a type".into(), IdDef::Const(_) => "a constant".into(), + IdDef::AggregateConst { .. 
} => "an aggregate constant".into(), IdDef::Func(_) => "a function".into(), @@ -46,6 +63,90 @@ impl IdDef { } } +impl Type { + fn aggregate_component_leaf_range_and_type( + self, + cx: &Context, + idx: u32, + ) -> Option<(Range, Type)> { + let (type_and_const_inputs, aggregate_shape) = match &cx[self].kind { + TypeKind::SpvInst { + spv_inst: _, + type_and_const_inputs, + value_lowering: spv::ValueLowering::Disaggregate(aggregate_shape), + } => (type_and_const_inputs, aggregate_shape), + _ => return None, + }; + let expect_type = |ty_or_ct| match ty_or_ct { + TypeOrConst::Type(ty) => ty, + TypeOrConst::Const(_) => unreachable!(), + }; + + let idx_usize = idx as usize; + let component_type = match aggregate_shape { + spv::AggregateShape::Struct { .. } => { + expect_type(*type_and_const_inputs.get(idx_usize)?) + } + &spv::AggregateShape::Array { fixed_len, .. } => { + if idx >= fixed_len { + return None; + } + expect_type(type_and_const_inputs[0]) + } + }; + let component_leaf_count = cx[component_type].disaggregated_leaf_count(); + + let component_leaf_range = match aggregate_shape { + spv::AggregateShape::Struct { per_field_leaf_range_end } => { + let end = per_field_leaf_range_end[idx_usize] as usize; + let start = end.checked_sub(component_leaf_count)?; + start..end + } + spv::AggregateShape::Array { .. } => { + let start = component_leaf_count.checked_mul(idx_usize)?; + let end = start.checked_add(component_leaf_count)?; + start..end + } + }; + Some((component_leaf_range, component_type)) + } + + // HACK(eddyb) `indices` is a `&mut` because it specifically only consumes + // the indices it needs, so when this function returns `Some`, all remaining + // indices will be left over for the caller to process itself. 
+ fn aggregate_component_path_leaf_range_and_type( + self, + cx: &Context, + indices: &mut impl Iterator, + ) -> Option<(Range, Type)> { + let (mut leaf_range, mut leaf_type) = + self.aggregate_component_leaf_range_and_type(cx, indices.next()?)?; + + while let spv::ValueLowering::Disaggregate(_) = cx[leaf_type].spv_value_lowering() { + let (sub_leaf_range, sub_leaf_type) = match indices.next() { + Some(i) => leaf_type.aggregate_component_leaf_range_and_type(cx, i)?, + None => break, + }; + + assert!(sub_leaf_range.end <= leaf_range.len()); + leaf_range.end = leaf_range.start + sub_leaf_range.end; + leaf_range.start += sub_leaf_range.start; + leaf_type = sub_leaf_type; + } + + Some((leaf_range, leaf_type)) + } +} + +/// Error type for when a SPIR-V type cannot have a `spv::ValueLowering`, i.e. +/// this type can only be used behind a pointer. Disaggregation won't be +/// performed, so illegal attempts at constructing values of this type will +/// be kept intact, but annotated with an error [`Diag`]nostic. +// +// FIXME(eddyb) include an actual error message in here, maybe a whole `Diag`. +#[derive(Clone)] +struct TypeIsIndirectOnly; + /// Deferred export, needed because the IDs are initially forward refs. enum Export { Linkage { @@ -80,7 +181,7 @@ struct IntraFuncInst { ids: SmallVec<[spv::Id; 4]>, } -// FIXME(eddyb) stop abusing `io::Error` for error reporting. +// FIXME(eddyb) stop abusing `io::Error` for error reporting and switch to `Diag`. 
fn invalid(reason: &str) -> io::Error { io::Error::new(io::ErrorKind::InvalidData, format!("malformed SPIR-V ({reason})")) } @@ -102,9 +203,10 @@ fn invalid_factory_for_spv_inst( // FIXME(eddyb) provide more information about any normalization that happened: // * stats about deduplication that occured through interning // * sets of unused global vars and functions (and types+consts only they use) -// FIXME(eddyb) consider introducing a "deferred error" system, where `spv::lower` -// (and more directproducers) can keep around errors in the SPIR-T IR, and still -// have the opportunity of silencing them e.g. by removing dead code. +// FIXME(eddyb) use `Diag` instead of `io::Error`, maybe with a return type like +// `Result<Module, IncompletelyLoweredModule>` where `IncompletelyLoweredModule` +// contains a `Module`, maps of all the SPIR-V IDs (to the SPIR-T definitions), +// global `Diag`s (where they can't be attached to specific `AttrSet`s), etc. impl Module { pub fn lower_from_spv_file(cx: Rc, path: impl AsRef) -> io::Result { Self::lower_from_spv_module_parser(cx, spv::read::ModuleParser::read_from_spv_file(path)?) 
@@ -563,6 +665,7 @@ impl Module { kind: TypeKind::SpvInst { spv_inst: spv::Inst { opcode, imms: [sc].into_iter().collect() }, type_and_const_inputs: [].into_iter().collect(), + value_lowering: Default::default(), }, }); id_defs.insert(id, IdDef::Type(ty)); @@ -589,9 +692,7 @@ impl Module { let ty = cx.intern(TypeDef { attrs: mem::take(&mut attrs), - kind: inst.as_canonical_type(&cx, &type_and_const_inputs).unwrap_or( - TypeKind::SpvInst { spv_inst: inst.without_ids, type_and_const_inputs }, - ), + kind: inst.without_ids.into_canonical_type_with(&cx, type_and_const_inputs), }); id_defs.insert(id, IdDef::Type(ty)); @@ -600,13 +701,68 @@ impl Module { || inst.always_lower_as_const() { let id = inst.result_id.unwrap(); + let ty = result_type.unwrap(); + let mut aggregate_leaves = match cx[ty].spv_value_lowering() { + spv::ValueLowering::Direct => None, + spv::ValueLowering::Disaggregate(_) => { + // HACK(eddyb) this expands `OpUndef`/`OpConstantNull`. + // FIXME(eddyb) this could potentially create a very + // inefficient large array, even when the intent can + // be expressed much more compactly in theory. + if inst.lower_const_by_distributing_to_aggregate_leaves() { + assert_eq!(inst.ids.len(), 0); + Some( + ty.disaggregated_leaf_types(&cx) + .map(|leaf_type| { + cx.intern(ConstDef { + attrs: Default::default(), + ty: leaf_type, + kind: inst + .as_canonical_const(&cx, leaf_type, &[]) + .unwrap_or_else(|| ConstKind::SpvInst { + spv_inst_and_const_inputs: Rc::new(( + inst.without_ids.clone(), + [].into_iter().collect(), + )), + }), + }) + }) + .collect(), + ) + } else if [wk.OpConstantComposite, wk.OpSpecConstantComposite] + .contains(&opcode) + { + // NOTE(eddyb) actual leaves gathered below, while + // collecting `const_inputs`. 
+ Some(SmallVec::with_capacity(cx[ty].disaggregated_leaf_count())) + } else { + attrs.push_diag( + &cx, + Diag::bug(["unsupported aggregate-producing constant".into()]), + ); + None + } + } + }; + let const_inputs: SmallVec<_> = inst .ids .iter() .map(|&id| match id_defs.get(&id) { - Some(&IdDef::Const(ct)) => Ok(ct), + Some(&IdDef::Const(ct)) => { + if let Some(aggregate_leaves) = &mut aggregate_leaves { + aggregate_leaves.push(ct); + } + Ok(ct) + } + Some(IdDef::AggregateConst { whole_const, whole_type: _, leaves }) => { + if let Some(aggregate_leaves) = &mut aggregate_leaves { + aggregate_leaves.extend(leaves.iter().copied()); + } + Ok(*whole_const) + } Some(id_def) => Err(id_def.descr(&cx)), None => Err(format!("a forward reference to %{id}")), }) @@ -617,6 +773,23 @@ impl Module { }) .collect::>()?; + if let (spv::ValueLowering::Disaggregate(_), Some(leaves)) = + (cx[ty].spv_value_lowering(), &aggregate_leaves) + { + if cx[ty].disaggregated_leaf_count() != leaves.len() { + attrs.push_diag( + &cx, + Diag::err([format!( + "aggregate leaf count mismatch (expected {}, found {})", + cx[ty].disaggregated_leaf_count(), + leaves.len() + ) + .into()]), + ); + aggregate_leaves = None; + } + } + let ct = cx.intern(ConstDef { attrs: mem::take(&mut attrs), ty, @@ -626,7 +799,17 @@ impl Module { } }), }); - id_defs.insert(id, IdDef::Const(ct)); + id_defs.insert( + id, + match (cx[ty].spv_value_lowering(), aggregate_leaves) { + (spv::ValueLowering::Disaggregate(_), Some(leaves)) => { + // FIXME(eddyb) this may lose semantic `attrs` when + // `leaves` are directly used. + IdDef::AggregateConst { whole_const: ct, whole_type: ty, leaves } + } + _ => IdDef::Const(ct), + }, + ); if inst_category != spec::InstructionCategory::Const { // `OpUndef` can appear either among constants, or in a @@ -659,6 +842,10 @@ impl Module { let initializer = initializer .map(|id| match id_defs.get(&id) { Some(&IdDef::Const(ct)) => Ok(ct), + Some(&IdDef::AggregateConst { whole_const, .. 
}) => { + // FIXME(eddyb) disaggregate global initializers. + Ok(whole_const) + } Some(id_def) => Err(id_def.descr(&cx)), None => Err(format!("a forward reference to %{id}")), }) @@ -706,8 +893,6 @@ impl Module { } let func_id = inst.result_id.unwrap(); - // FIXME(eddyb) hide this from SPIR-T, it's the function return - // type, *not* the function type, which is in `func_type`. let func_ret_type = result_type.unwrap(); let func_type_id = match (&inst.imms[..], &inst.ids[..]) { @@ -720,7 +905,7 @@ impl Module { let (func_type_ret_type, func_type_param_types) = match id_defs.get(&func_type_id) { Some(&IdDef::Type(ty)) => match &cx[ty].kind { - TypeKind::SpvInst { spv_inst, type_and_const_inputs } + TypeKind::SpvInst { spv_inst, type_and_const_inputs, .. } if spv_inst.opcode == wk.OpTypeFunction => { let mut types = @@ -774,17 +959,30 @@ impl Module { } }; - let func = module.funcs.define( - &cx, - FuncDecl { - attrs: mem::take(&mut attrs), - ret_type: func_ret_type, - params: func_type_param_types - .map(|ty| FuncParam { attrs: AttrSet::default(), ty }) - .collect(), - def, - }, - ); + // Always flatten aggregates in param and return types. + let ret_types = match &cx[func_ret_type].kind { + // HACK(eddyb) `OpTypeVoid` special-cased here as if it were + // an aggregate with `0` leaves. + TypeKind::SpvInst { spv_inst: func_ret_type_spv_inst, .. 
} + if func_ret_type_spv_inst.opcode == wk.OpTypeVoid => + { + [].into_iter().collect() + } + + _ => func_ret_type.disaggregated_leaf_types(&cx).collect(), + }; + let mut params = SmallVec::with_capacity(func_type_param_types.len()); + for param_type in func_type_param_types { + params.extend( + param_type + .disaggregated_leaf_types(&cx) + .map(|ty| FuncParam { attrs: AttrSet::default(), ty }), + ); + } + + let func = module + .funcs + .define(&cx, FuncDecl { attrs: mem::take(&mut attrs), ret_types, params, def }); id_defs.insert(func_id, IdDef::Func(func)); current_func_body = Some(FuncBody { func_id, func, insts: vec![] }); @@ -855,6 +1053,9 @@ impl Module { struct PhiKey { source_block_id: spv::Id, target_block_id: spv::Id, + // FIXME(eddyb) remove this, key phis only by the edge, and keep + // a per-edge list of phi input `spv::Id`s (with validation for + // missing entries/duplicates). target_phi_idx: u32, } @@ -957,13 +1158,29 @@ impl Module { None }; - #[derive(Copy, Clone)] - enum LocalIdDef { - Value(Type, Value), + // HACK(eddyb) this is generic to allow `IdDef::AggregateConst`s + // to be converted to `LocalIdDef::Value`s, inside `lookup_id`. + enum LocalIdDef>> { + Value { whole_type: Type, leaves: VL }, BlockLabel(ControlRegion), } - let mut local_id_defs = FxIndexMap::default(); + struct ValueRange { + region_or_data_inst: Either, + range: Range, + } + + impl ValueRange { + fn iter(&self) -> impl ExactSizeIterator + Clone { + let region_or_data_inst = self.region_or_data_inst; + self.range.clone().map(move |i| match region_or_data_inst { + Either::Left(region) => Value::ControlRegionInput { region, input_idx: i }, + Either::Right(inst) => Value::DataInstOutput { inst, output_idx: i }, + }) + } + } + + let mut local_id_defs = FxIndexMap::::default(); // Labels can be forward-referenced, so always have them present. 
local_id_defs.extend( @@ -987,15 +1204,70 @@ impl Module { let invalid = invalid_factory_for_spv_inst(&raw_inst.without_ids, result_id, ids); - // FIXME(eddyb) find a more compact name and/or make this a method. + let is_last_in_block = lookahead_raw_inst(1) + .map_or(true, |next_raw_inst| next_raw_inst.without_ids.opcode == wk.OpLabel); + + // HACK(eddyb) this is handled early because it's the only case + // where a `result_id` isn't a value, and `OpFunctionParameter` + // wants to be able to use common value result helpers. + if opcode == wk.OpLabel { + if is_last_in_block { + return Err(invalid("block lacks terminator instruction")); + } + + // An empty `ControlRegion` was defined earlier, + // to be able to have an entry in `local_id_defs`. + let control_region = match local_id_defs[&result_id.unwrap()] { + LocalIdDef::BlockLabel(control_region) => control_region, + LocalIdDef::Value { .. } => unreachable!(), + }; + let current_block_details = &block_details[&control_region]; + assert_eq!(current_block_details.label_id, result_id.unwrap()); + current_block_control_region_and_details = + Some((control_region, current_block_details)); + continue; + } + // FIXME(eddyb) this returns `LocalIdDef` even for global values. - let lookup_global_or_local_id_for_data_or_control_inst_input = |id| match id_defs - .get(&id) - { - Some(&IdDef::Const(ct)) => Ok(LocalIdDef::Value(cx[ct].ty, Value::Const(ct))), + let lookup_id = |id| match id_defs.get(&id) { + None => { + let local_id_def = local_id_defs.get(&id).ok_or_else(|| { + // FIXME(eddyb) scan the rest of the function for any + // instructions returning this ID, to report an invalid + // forward reference (use before def). + invalid(&format!("undefined ID %{id}")) + })?; + // HACK(eddyb) change the type of `leaves` within + // `LocalIdDef::Value` to support consts + // (see `IdDef::AggregateConst` case just below). 
+ Ok(match local_id_def { + LocalIdDef::Value { whole_type, leaves } => LocalIdDef::Value { + whole_type: *whole_type, + leaves: Either::Left( + leaves + .as_ref() + .map_left(|leaves| leaves.iter()) + .map_right(|leaves| leaves.iter().copied()), + ), + }, + &LocalIdDef::BlockLabel(label) => LocalIdDef::BlockLabel(label), + }) + } + Some(&IdDef::Const(ct)) => Ok(LocalIdDef::Value { + whole_type: cx[ct].ty, + leaves: Either::Right(Either::Left([Value::Const(ct)].into_iter())), + }), + Some(IdDef::AggregateConst { whole_const: _, whole_type, leaves }) => { + Ok(LocalIdDef::Value { + whole_type: *whole_type, + leaves: Either::Right(Either::Right( + leaves.iter().copied().map(Value::Const), + )), + }) + } Some(id_def @ IdDef::Type(_)) => Err(invalid(&format!( "unsupported use of {} as an operand for \ - an instruction in a function", + an instruction in a function", id_def.descr(&cx), ))), Some(id_def @ IdDef::Func(_)) => Err(invalid(&format!( @@ -1013,11 +1285,14 @@ impl Module { ty, kind: ConstKind::SpvStringLiteralForExtInst(*s), }); - Ok(LocalIdDef::Value(ty, Value::Const(ct))) + Ok(LocalIdDef::Value { + whole_type: ty, + leaves: Either::Right(Either::Left([Value::Const(ct)].into_iter())), + }) } else { Err(invalid(&format!( "unsupported use of {} outside `OpSource`, \ - `OpLine`, or `OpExtInst`", + `OpLine`, or `OpExtInst`", id_def.descr(&cx), ))) } @@ -1026,16 +1301,31 @@ impl Module { "unsupported use of {} outside `OpExtInst`", id_def.descr(&cx), ))), - // FIXME(eddyb) scan the rest of the function for any - // instructions returning this ID, to report an invalid - // forward reference (use before def). - None => local_id_defs - .get(&id) - .copied() - .ok_or_else(|| invalid(&format!("undefined ID %{id}",))), + }; + + // FIXME(eddyb) is this even necessary anymore? 
+ let value_from_region_inputs_or_data_inst_outputs = + |region_or_data_inst, disaggregated_leaf_range| LocalIdDef::Value { + whole_type: result_type.unwrap(), + leaves: Either::Left(ValueRange { + region_or_data_inst, + range: disaggregated_leaf_range, + }), + }; + + // Helper shared by `OpFunctionParameter` and `OpPhi`. + let attrs_for_result_leaf = |leaf_type: Type| { + if result_type == Some(leaf_type) { + attrs + } else { + // FIXME(eddyb) this may lose semantic `attrs`. + AttrSet::default() + } }; if opcode == wk.OpFunctionParameter { + let result_type = result_type.unwrap(); + if current_block_control_region_and_details.is_some() { return Err(invalid( "out of order: `OpFunctionParameter`s should come \ @@ -1045,18 +1335,29 @@ impl Module { assert!(imms.is_empty() && ids.is_empty()); - let ty = result_type.unwrap(); - params.push(FuncParam { attrs, ty }); + let param_start = params.len(); + params.extend( + result_type + .disaggregated_leaf_types(&cx) + .map(|ty| FuncParam { attrs: attrs_for_result_leaf(ty), ty }), + ); + let param_end = params.len(); + if let Some(func_def_body) = &mut func_def_body { let body_inputs = &mut func_def_body.at_mut_body().def().inputs; - let input_idx = u32::try_from(body_inputs.len()).unwrap(); - body_inputs.push(ControlRegionInputDecl { attrs, ty }); + let start = u32::try_from(body_inputs.len()).unwrap(); + body_inputs.extend( + params[param_start..param_end].iter().map( + |&FuncParam { attrs, ty }| ControlRegionInputDecl { attrs, ty }, + ), + ); + let end = u32::try_from(body_inputs.len()).unwrap(); local_id_defs.insert( result_id.unwrap(), - LocalIdDef::Value( - ty, - Value::ControlRegionInput { region: func_def_body.body, input_idx }, + value_from_region_inputs_or_data_inst_outputs( + Either::Left(func_def_body.body), + start..end, ), ); } @@ -1064,26 +1365,6 @@ impl Module { } let func_def_body = func_def_body.as_deref_mut().unwrap(); - let is_last_in_block = lookahead_raw_inst(1) - .map_or(true, |next_raw_inst| 
next_raw_inst.without_ids.opcode == wk.OpLabel); - - if opcode == wk.OpLabel { - if is_last_in_block { - return Err(invalid("block lacks terminator instruction")); - } - - // An empty `ControlRegion` was defined earlier, - // to be able to have an entry in `local_id_defs`. - let control_region = match local_id_defs[&result_id.unwrap()] { - LocalIdDef::BlockLabel(control_region) => control_region, - LocalIdDef::Value(..) => unreachable!(), - }; - let current_block_details = &block_details[&control_region]; - assert_eq!(current_block_details.label_id, result_id.unwrap()); - current_block_control_region_and_details = - Some((control_region, current_block_details)); - continue; - } let (current_block_control_region, current_block_details) = current_block_control_region_and_details.ok_or_else(|| { invalid("out of order: not expected before the function's blocks") @@ -1102,23 +1383,6 @@ impl Module { } let mut target_inputs = FxIndexMap::default(); - let descr_phi_case = |phi_key: &PhiKey| { - format!( - "`OpPhi` (#{} in %{})'s case for source block %{}", - phi_key.target_phi_idx, - phi_key.target_block_id, - phi_key.source_block_id, - ) - }; - let phi_value_id_to_value = |phi_key: &PhiKey, id| { - match lookup_global_or_local_id_for_data_or_control_inst_input(id)? { - LocalIdDef::Value(_, v) => Ok(v), - LocalIdDef::BlockLabel { .. 
} => Err(invalid(&format!( - "unsupported use of block label as the value for {}", - descr_phi_case(phi_key) - ))), - } - }; let mut record_cfg_edge = |target_block| -> io::Result<()> { use indexmap::map::Entry; @@ -1134,28 +1398,55 @@ impl Module { Entry::Vacant(entry) => entry, }; - let inputs = (0..target_block_details.phi_count).map(|target_phi_idx| { + let mut target_inputs = SmallVec::new(); + for target_phi_idx in 0..target_block_details.phi_count { let phi_key = PhiKey { source_block_id: current_block_details.label_id, target_block_id: target_block_details.label_id, target_phi_idx, }; + let descr_phi_case = || { + format!( + "`OpPhi` (#{} in %{})'s case for source block %{}", + phi_key.target_phi_idx, + phi_key.target_block_id, + phi_key.source_block_id, + ) + }; + let phi_value_ids = phi_to_values.swap_remove(&phi_key).unwrap_or_default(); - match phi_value_ids[..] { - [] => Err(invalid(&format!( - "{} is missing", - descr_phi_case(&phi_key) - ))), - [id] => phi_value_id_to_value(&phi_key, id), - [..] => Err(invalid(&format!( - "{} is duplicated", - descr_phi_case(&phi_key) - ))), + let phi_value_id = match phi_value_ids[..] { + [] => { + return Err(invalid(&format!( + "{} is missing", + descr_phi_case() + ))); + } + [id] => id, + [..] => { + return Err(invalid(&format!( + "{} is duplicated", + descr_phi_case() + ))); + } + }; + + match lookup_id(phi_value_id)? { + LocalIdDef::Value { leaves, .. } => { + target_inputs.extend(leaves); + } + LocalIdDef::BlockLabel(_) => { + return Err(invalid(&format!( + "unsupported use of block label as the value for {}", + descr_phi_case() + ))); + } } - }); - target_inputs_entry.insert(inputs.collect::>()?); + } + + target_inputs_entry.insert(target_inputs); Ok(()) }; @@ -1166,16 +1457,33 @@ impl Module { let mut input_types = SmallVec::<[_; 2]>::new(); let mut targets = SmallVec::new(); for &id in ids { - match lookup_global_or_local_id_for_data_or_control_inst_input(id)? 
{ - LocalIdDef::Value(ty, v) => { + match lookup_id(id)? { + LocalIdDef::Value { whole_type, leaves, .. } => { if !targets.is_empty() { return Err(invalid( "out of order: value operand \ after target label ID", )); } - inputs.push(v); - input_types.push(ty); + + match cx[whole_type].spv_value_lowering() { + spv::ValueLowering::Direct => {} + + // Returns are "lossily" disaggregated, just like + // function's signatures and calls to them. + spv::ValueLowering::Disaggregate(_) + if opcode == wk.OpReturnValue => {} + + spv::ValueLowering::Disaggregate(_) => { + return Err(invalid( + "unsupported aggregate value operand, \ + in non-return terminator instruction", + )); + } + } + + inputs.extend(leaves); + input_types.push(whole_type); } LocalIdDef::BlockLabel(target) => { record_cfg_edge(target)?; @@ -1189,7 +1497,7 @@ impl Module { assert!(targets.is_empty() && inputs.is_empty()); cfg::ControlInstKind::Unreachable } else if [wk.OpReturn, wk.OpReturnValue].contains(&opcode) { - assert!(targets.is_empty() && inputs.len() <= 1); + assert!(targets.is_empty()); cfg::ControlInstKind::Return } else if targets.is_empty() { cfg::ControlInstKind::ExitInvocation(cfg::ExitInvocationKind::SpvInst( @@ -1271,7 +1579,12 @@ impl Module { current_block_control_region, cfg::ControlInst { attrs, kind, inputs, targets, target_inputs }, ); - } else if opcode == wk.OpPhi { + continue; + } + + if opcode == wk.OpPhi { + let result_type = result_type.unwrap(); + if !current_block_control_region_def.children.is_empty() { return Err(invalid( "out of order: `OpPhi`s should come before \ @@ -1279,25 +1592,26 @@ impl Module { )); } - let ty = result_type.unwrap(); - - let input_idx = - u32::try_from(current_block_control_region_def.inputs.len()).unwrap(); - current_block_control_region_def - .inputs - .push(ControlRegionInputDecl { attrs, ty }); + let inputs = &mut current_block_control_region_def.inputs; + let start = u32::try_from(inputs.len()).unwrap(); + inputs.extend( + 
result_type.disaggregated_leaf_types(&cx).map(|ty| { + ControlRegionInputDecl { attrs: attrs_for_result_leaf(ty), ty } + }), + ); + let end = u32::try_from(inputs.len()).unwrap(); local_id_defs.insert( result_id.unwrap(), - LocalIdDef::Value( - ty, - Value::ControlRegionInput { - region: current_block_control_region, - input_idx, - }, + value_from_region_inputs_or_data_inst_outputs( + Either::Left(current_block_control_region), + start..end, ), ); - } else if [wk.OpSelectionMerge, wk.OpLoopMerge].contains(&opcode) { + continue; + } + + if [wk.OpSelectionMerge, wk.OpLoopMerge].contains(&opcode) { let is_second_to_last_in_block = lookahead_raw_inst(2) .map_or(true, |next_raw_inst| { next_raw_inst.without_ids.opcode == wk.OpLabel @@ -1314,12 +1628,10 @@ impl Module { // impact on the shape of a loop, for restructurization. if opcode == wk.OpLoopMerge { assert_eq!(ids.len(), 2); - let loop_merge_target = - match lookup_global_or_local_id_for_data_or_control_inst_input(ids[0])? - { - LocalIdDef::Value(..) => return Err(invalid("expected label ID")), - LocalIdDef::BlockLabel(target) => target, - }; + let loop_merge_target = match lookup_id(ids[0])? { + LocalIdDef::Value { .. } => return Err(invalid("expected label ID")), + LocalIdDef::BlockLabel(target) => target, + }; func_def_body .unstructured_cfg @@ -1333,107 +1645,15 @@ impl Module { // especially wrt the `SelectionControl` and `LoopControl` // operands, but it's not obvious how they should map to // some "structured regions" replacement for the CFG. - } else { - let mut ids = &ids[..]; - let kind = if let Some(kind) = raw_inst.without_ids.as_canonical_data_inst_kind( - &cx, - result_type.map(|ty| [ty]).as_ref().map_or(&[][..], |tys| &tys[..]), - ) { - // FIXME(eddyb) sanity-check the number/types of inputs. 
- kind - } else if opcode == wk.OpFunctionCall { - assert!(imms.is_empty()); - let callee_id = ids[0]; - let maybe_callee = id_defs - .get(&callee_id) - .map(|id_def| match *id_def { - IdDef::Func(func) => Ok(func), - _ => Err(id_def.descr(&cx)), - }) - .transpose() - .map_err(|descr| { - invalid(&format!( - "unsupported use of {descr} as the `OpFunctionCall` callee" - )) - })?; - - match maybe_callee { - Some(callee) => { - ids = &ids[1..]; - DataInstKind::FuncCall(callee) - } - - // HACK(eddyb) this should be an error, but it shows - // up in Rust-GPU output (likely a zombie?). - None => DataInstKind::SpvInst(raw_inst.without_ids.clone()), - } - } else if opcode == wk.OpExtInst { - let ext_set_id = ids[0]; - ids = &ids[1..]; - - let inst = match imms[..] { - [spv::Imm::Short(kind, inst)] => { - assert_eq!(kind, wk.LiteralExtInstInteger); - inst - } - _ => unreachable!(), - }; - - let ext_set = match id_defs.get(&ext_set_id) { - Some(&IdDef::SpvExtInstImport(name)) => Ok(name), - Some(id_def) => Err(id_def.descr(&cx)), - None => Err(format!("unknown ID %{ext_set_id}")), - } - .map_err(|descr| { - invalid(&format!( - "unsupported use of {descr} as the `OpExtInst` \ - extended instruction set ID" - )) - })?; + continue; + } - DataInstKind::SpvExtInst { ext_set, inst } - } else { - DataInstKind::SpvInst(raw_inst.without_ids.clone()) - }; + // All control-flow instructions have been handled above. + // Only `DataInst`s get generated below here. - let data_inst_def = DataInstDef { - attrs, - form: cx.intern(DataInstFormDef { - kind, - output_type: result_id - .map(|_| { - result_type.ok_or_else(|| { - invalid( - "expected value-producing instruction, \ - with a result type", - ) - }) - }) - .transpose()?, - }), - inputs: ids - .iter() - .map(|&id| { - match lookup_global_or_local_id_for_data_or_control_inst_input(id)? - { - LocalIdDef::Value(_, v) => Ok(v), - LocalIdDef::BlockLabel { .. 
} => Err(invalid( - "unsupported use of block label as a value, \ - in non-terminator instruction", - )), - } - }) - .collect::>()?, - }; + let mut append_data_inst = |data_inst_def: DataInstDef| { let inst = func_def_body.data_insts.define(&cx, data_inst_def.into()); - if let Some(result_id) = result_id { - local_id_defs.insert( - result_id, - LocalIdDef::Value(result_type.unwrap(), Value::DataInstOutput(inst)), - ); - } - let current_block_control_node = current_block_control_region_def .children .iter() @@ -1464,6 +1684,333 @@ impl Module { } _ => unreachable!(), } + + inst + }; + + let lookup_value_id = |id| match lookup_id(id)? { + LocalIdDef::Value { whole_type, leaves } => Ok((whole_type, leaves)), + LocalIdDef::BlockLabel(_) => Err(invalid( + "unsupported use of block label as a value, \ + in non-terminator instruction", + )), + }; + + // Special-case instructions which deal with aggregates as + // "containers" for their leaves, and so have an effect which + // can be interpreted eagerly on the disaggregated form. 
+ // FIXME(eddyb) this may lose semantic `attrs` + let eagerly_lowered_result = if opcode == wk.OpCompositeConstruct { + let result_type = result_type.unwrap(); + + match cx[result_type].spv_value_lowering() { + spv::ValueLowering::Direct => None, + spv::ValueLowering::Disaggregate(_) => { + let mut all_leaves = + SmallVec::with_capacity(cx[result_type].disaggregated_leaf_count()); + for &id in ids { + let (_, leaves) = lookup_value_id(id)?; + all_leaves.extend(leaves); + } + if all_leaves.len() == cx[result_type].disaggregated_leaf_count() { + Some(LocalIdDef::Value { + whole_type: result_type, + leaves: Either::Right(all_leaves), + }) + } else { + None + } + } + } + } else if [wk.OpCompositeExtract, wk.OpCompositeInsert].contains(&opcode) { + let result_type = result_type.unwrap(); + + let (&composite_id, ids_without_last) = ids.split_last().unwrap(); + let (composite_type, leaves) = lookup_value_id(composite_id)?; + + // HACK(eddyb) `replace_component` and `rebuild_composite` + // are always both `None` or both `Some`, but splitting the + // two aspects of `OpCompositeInsert` makes it easier later. + let (component_type, replace_component, rebuild_composite); + match ids_without_last[..] { + [] => { + component_type = result_type; + replace_component = None; + rebuild_composite = None; + } + [replacement_component_id] => { + let (replacement_component_type, replacement_component_leaves) = + lookup_value_id(replacement_component_id)?; + + component_type = replacement_component_type; + replace_component = Some(replacement_component_leaves); + rebuild_composite = Some(result_type); + } + _ => unreachable!(), + } + + // HACK(eddyb) this is a `try {...}`-like use of a closure. 
+ (|| { + if let Some(expected_type) = rebuild_composite { + if composite_type != expected_type { + return None; + } + } + + let mut imms = imms.iter(); + let (leaf_range, leaf_type) = match cx[composite_type].spv_value_lowering() + { + spv::ValueLowering::Direct => return None, + spv::ValueLowering::Disaggregate(_) => composite_type + .aggregate_component_path_leaf_range_and_type( + &cx, + &mut imms.by_ref().map(|&imm| match imm { + spv::Imm::Short(_, i) => i, + _ => unreachable!(), + }), + )?, + }; + let non_aggregate_indexing_imms = imms.as_slice(); + + if non_aggregate_indexing_imms.is_empty() && leaf_type != component_type { + return None; + } + + let mut component_leaves = + leaves.clone().skip(leaf_range.start).take(leaf_range.len()); + + // If there's any leftover indices they must be indexing + // into a vector/matrix, which requires separate handling. + let component_leaves = if !non_aggregate_indexing_imms.is_empty() { + assert_eq!(component_leaves.len(), 1); + let non_aggregate_composite = component_leaves.next().unwrap(); + + let leaf_spv_inst = spv::Inst { + opcode, + imms: non_aggregate_indexing_imms.iter().copied().collect(), + }; + let leaf_output_types = [match rebuild_composite { + Some(_) => leaf_type, + None => component_type, + }]; + let leaf_inst = append_data_inst(DataInstDef { + attrs, + form: cx.intern(DataInstFormDef { + kind: leaf_spv_inst + .as_canonical_data_inst_kind(&cx, &leaf_output_types) + .unwrap_or(DataInstKind::SpvInst( + leaf_spv_inst, + spv::InstLowering::default(), + )), + output_types: leaf_output_types.into_iter().collect(), + }), + inputs: replace_component + .map(|mut replacement_leaves| { + assert_eq!(replacement_leaves.len(), 1); + replacement_leaves.next().unwrap() + }) + .into_iter() + .chain([non_aggregate_composite]) + .collect(), + }); + Either::Left( + [Value::DataInstOutput { inst: leaf_inst, output_idx: 0 }] + .into_iter(), + ) + } else { + Either::Right( + replace_component + 
.map_or(Either::Left(component_leaves), Either::Right), + ) + }; + + assert_eq!( + component_leaves.len(), + cx[component_type].disaggregated_leaf_count() + ); + + let leaves = match rebuild_composite { + Some(_) => leaves + .clone() + .take(leaf_range.start) + .chain(component_leaves) + .chain(leaves.skip(leaf_range.end)) + .collect(), + None => component_leaves.collect(), + }; + + Some(LocalIdDef::Value { + whole_type: result_type, + // FIXME(eddyb) avoid allocating somehow, like + // try "recompressing" into a `ValueRange`, or + // preserving that form throughout above? + leaves: Either::Right(leaves), + }) + })() + } else { + None + }; + if let Some(def) = eagerly_lowered_result { + local_id_defs.insert(result_id.unwrap(), def); + continue; + } + + let mut ids = &ids[..]; + let mut kind = if opcode == wk.OpFunctionCall { + assert!(imms.is_empty()); + let callee_id = ids[0]; + let maybe_callee = id_defs + .get(&callee_id) + .map(|id_def| match *id_def { + IdDef::Func(func) => Ok(func), + _ => Err(id_def.descr(&cx)), + }) + .transpose() + .map_err(|descr| { + invalid(&format!( + "unsupported use of {descr} as the `OpFunctionCall` callee" + )) + })?; + + match maybe_callee { + Some(callee) => { + ids = &ids[1..]; + DataInstKind::FuncCall(callee) + } + + // HACK(eddyb) this should be an error, but it shows + // up in Rust-GPU output (likely a zombie?). + None => DataInstKind::SpvInst( + raw_inst.without_ids.clone(), + spv::InstLowering::default(), + ), + } + } else if opcode == wk.OpExtInst { + let ext_set_id = ids[0]; + ids = &ids[1..]; + + let inst = match imms[..] 
{ + [spv::Imm::Short(kind, inst)] => { + assert_eq!(kind, wk.LiteralExtInstInteger); + inst + } + _ => unreachable!(), + }; + + let ext_set = match id_defs.get(&ext_set_id) { + Some(&IdDef::SpvExtInstImport(name)) => Ok(name), + Some(id_def) => Err(id_def.descr(&cx)), + None => Err(format!("unknown ID %{ext_set_id}")), + } + .map_err(|descr| { + invalid(&format!( + "unsupported use of {descr} as the `OpExtInst` \ + extended instruction set ID" + )) + })?; + + DataInstKind::SpvExtInst { + ext_set, + inst, + lowering: spv::InstLowering::default(), + } + } else { + DataInstKind::SpvInst( + raw_inst.without_ids.clone(), + spv::InstLowering::default(), + ) + }; + + // HACK(eddyb) only factored out due to `kind`'s mutable borrow. + let call_ret_type = match kind { + DataInstKind::FuncCall(_) => Some(result_type.unwrap()), + _ => None, + }; + + let mut spv_inst_lowering = match &mut kind { + DataInstKind::SpvInst(_, lowering) + | DataInstKind::SpvExtInst { lowering, .. } => Some(lowering), + + // NOTE(eddyb) function signatures and calls keep their + // disaggregation even when lifting back to SPIR-V, so + // no `spv::InstLowering` is tracked for them. + DataInstKind::FuncCall(_) => None, + + DataInstKind::Scalar(_) | DataInstKind::Vector(_) | DataInstKind::QPtr(_) => { + unreachable!() + } + }; + + let output_types: SmallVec<_> = result_id + .map(|_| { + let result_type = result_type.unwrap(); + if let Some(spv_inst_lowering) = &mut spv_inst_lowering { + spv_inst_lowering.disaggregated_output = + match cx[result_type].spv_value_lowering() { + spv::ValueLowering::Direct => None, + spv::ValueLowering::Disaggregate(_) => Some(result_type), + }; + } + + // HACK(eddyb) `OpTypeVoid` special-cased for calls + // as if it were an aggregate with `0` leaves. + let ret_void = call_ret_type.is_some_and(|ty| match &cx[ty].kind { + TypeKind::SpvInst { spv_inst: ret_type_spv_inst, .. 
} => { + ret_type_spv_inst.opcode == wk.OpTypeVoid + } + _ => false, + }); + if ret_void { + [].into_iter().collect() + } else { + result_type.disaggregated_leaf_types(&cx).collect() + } + }) + .unwrap_or_default(); + let output_types_len = output_types.len(); + + let mut inputs = SmallVec::with_capacity(ids.len()); + for &id in ids { + let (whole_input_type, leaves) = lookup_value_id(id)?; + + let start = u32::try_from(inputs.len()).unwrap(); + inputs.extend(leaves); + let end = u32::try_from(inputs.len()).unwrap(); + + if let spv::ValueLowering::Disaggregate(_) = + cx[whole_input_type].spv_value_lowering() + { + if let Some(lowering) = &mut spv_inst_lowering { + lowering.disaggregated_inputs.push((start..end, whole_input_type)); + } + } + } + + if let DataInstKind::SpvInst(spv_inst, lowering) = &kind { + if lowering.disaggregated_inputs.is_empty() { + if let Some(canonical_kind) = + spv_inst.as_canonical_data_inst_kind(&cx, &output_types) + { + // FIXME(eddyb) sanity-check the number/types of inputs. + kind = canonical_kind; + } + } + } + + let inst = append_data_inst(DataInstDef { + attrs, + form: cx.intern(DataInstFormDef { kind, output_types }), + inputs, + }); + + if let Some(result_id) = result_id { + local_id_defs.insert( + result_id, + value_from_region_inputs_or_data_inst_outputs( + Either::Right(inst), + 0..u32::try_from(output_types_len).unwrap(), + ), + ); } } diff --git a/src/spv/mod.rs b/src/spv/mod.rs index 09728c1a..2b0156c5 100644 --- a/src/spv/mod.rs +++ b/src/spv/mod.rs @@ -10,11 +10,12 @@ pub mod read; pub mod spec; pub mod write; -use crate::{FxIndexMap, InternedStr}; +use crate::{Context, FxIndexMap, InternedStr, Type, TypeDef, TypeKind, TypeOrConst}; use smallvec::SmallVec; use std::collections::{BTreeMap, BTreeSet}; use std::iter; use std::num::NonZeroU32; +use std::ops::Range; use std::string::FromUtf8Error; /// Semantic properties of a SPIR-V module (not tied to any IDs). 
@@ -51,6 +52,316 @@ pub struct DebugSources { pub file_contents: FxIndexMap, } +/// Most SPIR-V types can be used as SPIR-V (SSA) value types, but some require +/// non-trivial lowering into SPIR-T [`Value`](crate::Value)s (e.g. expanding +/// one SPIR-V value into any number of *valid* SPIR-T values). +// +// FIXME(eddyb) aggregates without known leaf counts using `Direct` is worse than +// treating them as an error (and e.g. generating `Diag`s), but it's also simpler. +#[derive(Clone, Default, PartialEq, Eq, Hash)] +pub enum ValueLowering { + /// SPIR-V values of this type map to SPIR-T [`Value`](crate::Value)s with the same type + /// (see [`Value`](crate::Value) documentation for more details, and valid types). + #[default] + Direct, + + /// SPIR-V values of this type can't be kept intract in SPIR-T, but instead + /// require decomposion into their "leaves", i.e. valid SPIR-T [`Value`](crate::Value)s. + Disaggregate(AggregateShape), +} + +#[derive(Clone, PartialEq, Eq, Hash)] +pub enum AggregateShape { + Struct { per_field_leaf_range_end: SmallVec<[u32; 4]> }, + Array { fixed_len: u32, total_leaf_count: u32 }, +} + +impl AggregateShape { + // FIXME(eddyb) force this to be used via some kind of forced canonicalization. 
+ pub fn compute( + cx: &Context, + spv_inst: &Inst, + type_and_const_inputs: &[TypeOrConst], + ) -> Option { + let wk = &spec::Spec::get().well_known; + + if spv_inst.opcode == wk.OpTypeStruct { + let mut leaf_count = 0u32; + let mut per_field_leaf_range_end = SmallVec::new(); + for &ty_or_ct in type_and_const_inputs { + let field_type = match ty_or_ct { + TypeOrConst::Type(ty) => ty, + TypeOrConst::Const(_) => return None, + }; + let field_leaf_count = cx[field_type].disaggregated_leaf_count_u32(); + + leaf_count = leaf_count.checked_add(field_leaf_count)?; + per_field_leaf_range_end.push(leaf_count); + } + Some(Self::Struct { per_field_leaf_range_end }) + } else if spv_inst.opcode == wk.OpTypeArray { + let (elem_type, len) = match type_and_const_inputs[..] { + [TypeOrConst::Type(elem_type), TypeOrConst::Const(len)] => (elem_type, len), + _ => return None, + }; + + // NOTE(eddyb) this can legally be `None` when the length of + // the array is given by a specialization constant. + let fixed_len = len.as_scalar(cx).and_then(|len| len.int_as_u32()); + let fixed_len = fixed_len?; + + let elem_leaf_count = cx[elem_type].disaggregated_leaf_count_u32(); + + Some(Self::Array { + fixed_len, + total_leaf_count: elem_leaf_count.checked_mul(fixed_len)?, + }) + } else { + None + } + } +} + +// FIXME(eddyb) not the best place to put these utilities, but they're used in +// both `spv::lower` and `spv::lift` (and they use private methods defined here). +// FIXME(eddyb) consider moving some of this to `spv::canonical`. +impl TypeDef { + fn spv_value_lowering(&self) -> &ValueLowering { + match &self.kind { + TypeKind::Scalar(_) + | TypeKind::Vector(_) + | TypeKind::QPtr + | TypeKind::SpvStringLiteralForExtInst => &ValueLowering::Direct, + TypeKind::SpvInst { value_lowering, .. 
} => value_lowering, + } + } + + fn disaggregated_leaf_count(&self) -> usize { + self.disaggregated_leaf_count_u32() as usize + } + + fn disaggregated_leaf_count_u32(&self) -> u32 { + match self.spv_value_lowering() { + ValueLowering::Direct => 1, + ValueLowering::Disaggregate(AggregateShape::Struct { per_field_leaf_range_end }) => { + per_field_leaf_range_end.last().copied().unwrap_or(0) + } + &ValueLowering::Disaggregate(AggregateShape::Array { total_leaf_count, .. }) => { + total_leaf_count + } + } + } +} + +/// Tree-like (preorder) traversal tool for [`ValueLowering::Disaggregate`] types. +struct AggregateCursor<'a> { + cx: &'a Context, + // FIXME(eddyb) should this cache any references into `&Context`? + current: Type, + parent_component_path: SmallVec<[(Type, u32); 8]>, +} + +impl AggregateCursor<'_> { + // HACK(eddyb) this returns `true` iff a new node was found. + fn try_advance(&mut self) -> bool { + // FIXME(eddyb) this isn't the best organization possible. + let cx = self.cx; + let get_component = move |ty: Type, idx: u32| -> Option { + let ty_def = &cx[ty]; + let type_input_idx = match ty_def.spv_value_lowering() { + ValueLowering::Direct => return None, + ValueLowering::Disaggregate(AggregateShape::Struct { .. }) => idx, + &ValueLowering::Disaggregate(AggregateShape::Array { fixed_len, .. }) => { + if idx >= fixed_len { + return None; + } + 0 + } + }; + let type_and_const_inputs = match &ty_def.kind { + TypeKind::Scalar(_) + | TypeKind::Vector(_) + | TypeKind::QPtr + | TypeKind::SpvStringLiteralForExtInst => &[][..], + TypeKind::SpvInst { type_and_const_inputs, .. } => &type_and_const_inputs[..], + }; + let expect_type = |ty_or_ct| match ty_or_ct { + TypeOrConst::Type(ty) => ty, + TypeOrConst::Const(_) => unreachable!(), + }; + Some(expect_type(*type_and_const_inputs.get(usize::try_from(type_input_idx).ok()?)?)) + }; + + // Try descending first, into the first child. 
+ if let Some(first_child_type) = get_component(self.current, 0) { + self.parent_component_path.push((self.current, 0)); + self.current = first_child_type; + return true; + } + + // Try ascending until there is a next sibling to descend into, but only + // modifying `self` iff any such node is found, otherwise calling this + // method without checking its success, could result in infinite cycles. + for depth in (0..self.parent_component_path.len()).rev() { + let (ancestor_type, ancestor_child_idx) = &mut self.parent_component_path[depth]; + if let Some(sibling_idx) = ancestor_child_idx.checked_add(1) { + if let Some(sibling_type) = get_component(*ancestor_type, sibling_idx) { + *ancestor_child_idx = sibling_idx; + self.current = sibling_type; + self.parent_component_path.truncate(depth + 1); + return true; + } + } + } + + false + } + + // FIXME(eddyb) can't find a great name for this - crucially, if this is a + // leaf, it's a noop, it doesn't find the next leaf! + // HACK(eddyb) this returns `true` iff a leaf node was found. + fn try_ensure_at_leaf(&mut self) -> bool { + loop { + if let ValueLowering::Direct = self.cx[self.current].spv_value_lowering() { + return true; + } + if !self.try_advance() { + return false; + } + } + } +} + +/// Recursively flattening iterator for [`ValueLowering::Disaggregate`] types. +struct DisaggregatedLeafTypes<'a>(Option>); + +impl Iterator for DisaggregatedLeafTypes<'_> { + type Item = Type; + + fn size_hint(&self) -> (usize, Option) { + // HACK(eddyb) only compute a size hint for the fresh iterator. + if let Self(Some(cursor)) = self { + if cursor.parent_component_path.is_empty() { + let leaf_count = cursor.cx[cursor.current].disaggregated_leaf_count(); + return (leaf_count, Some(leaf_count)); + } + } + (0, None) + } + + fn next(&mut self) -> Option { + let cursor = self.0.as_mut()?; + let next = cursor.try_ensure_at_leaf().then_some(cursor.current); + + // Record advancement failure, ensuring future calls to `next` return `None`. 
+ if !(next.is_some() && cursor.try_advance()) { + *self = Self(None); + } + + next + } +} + +// HACK(eddyb) `impl Trait` helper for when a non-bound lifetime needs capturing. +// FIXME(eddyb) should `map_with_parent_component_path` even use `impl Trait`? +trait Captures<'a> {} +impl Captures<'_> for T {} + +impl<'a> DisaggregatedLeafTypes<'a> { + // HACK(eddyb) like `map`, but acessing the `parent_component_path` of + // the inner `AggregateCursor`, when it is possitioned at each leaf. + fn map_with_parent_component_path( + self, + mut f: impl FnMut(Type, &[(Type, u32)]) -> T, + ) -> impl Iterator + Captures<'a> { + let Self(mut cursor_slot) = self; + iter::from_fn(move || { + let cursor = cursor_slot.as_mut()?; + let next = cursor + .try_ensure_at_leaf() + .then(|| f(cursor.current, &cursor.parent_component_path)); + + // Record advancement failure, ensuring future calls to `next` return `None`. + if !(next.is_some() && cursor.try_advance()) { + cursor_slot = None; + } + + next + }) + } +} + +// FIXME(eddyb) not the best place to put these utilities, but they're used in +// both `spv::lower` and `spv::lift` (and they use private methods defined here). +// FIXME(eddyb) consider moving some of this to `spv::canonical`. +impl Type { + fn disaggregated_leaf_types(self, cx: &Context) -> DisaggregatedLeafTypes<'_> { + DisaggregatedLeafTypes(Some(AggregateCursor { + cx, + current: self, + parent_component_path: SmallVec::new(), + })) + } +} + +/// Aspects of how a [`spv::Inst`](Inst) was produced by [`spv::lower`](lower), +/// which were otherwise lost in the SPIR-T form, but are still required for +/// [`spv::lift`](lift) to reproduce the original SPIR-V instruction. +/// +/// Primarily used within [`DataInstKind`](crate::DataInstKind) due to SPIR-V +/// instructions that take or produce "aggregates" (`OpTypeStruct`/`OpTypeArray`) +/// and which may require the exact original types (i.e. 
may not be valid when +/// using a fresh `OpTypeStruct` of the flattened non-"aggregate" components). +#[derive(Clone, Default, PartialEq, Eq, Hash)] +pub struct InstLowering { + // FIXME(eddyb) should this be named "result" instead of "output", somewhat + // standardizing the idea that 1 SPIR-V "result" maps to N SPIR-T "outputs"? + pub disaggregated_output: Option, + + // FIXME(eddyb) only store the starts, and get the leaf counts from the `Type`. + pub disaggregated_inputs: SmallVec<[(Range, Type); 1]>, +} + +/// Helper type for [`InstLowering::reaggreate_inputs`], which corresponds to +/// one or more inputs of a SPIR-V instruction (after being lowered to SPIR-T), +/// according to the [`InstLowering`] (and its `disaggregated_inputs` field). +#[derive(Copy, Clone)] +pub enum ReaggregatedIdOperand<'a, T> { + Direct(T), + Aggregate { ty: Type, leaves: &'a [T] }, +} + +impl InstLowering { + pub fn reaggreate_inputs<'a, T: Copy>( + &'a self, + inputs: &'a [T], + ) -> impl Iterator> + Clone { + // HACK(eddyb) the `None` at the end handles remainining direct inputs. + let mut prev_end = 0; + self.disaggregated_inputs.iter().map(Some).chain([None]).flat_map( + move |maybe_disaggregated| { + // FIXME(eddyb) the range manipulation is all over the place here. + let direct_range = prev_end + ..maybe_disaggregated.map_or(inputs.len(), |(range, _)| range.start as usize); + assert!(direct_range.start <= direct_range.end); + prev_end = direct_range.end; + + let direct_inputs = + inputs[direct_range].iter().copied().map(ReaggregatedIdOperand::Direct); + + let aggregate_input = maybe_disaggregated.map(|(range, ty)| { + let leaves_range = range.start as usize..range.end as usize; + prev_end = leaves_range.end; + + ReaggregatedIdOperand::Aggregate { ty: *ty, leaves: &inputs[leaves_range] } + }); + + direct_inputs.chain(aggregate_input) + }, + ) + } +} + /// A SPIR-V instruction, in its minimal form (opcode and immediate operands). 
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Inst { diff --git a/src/spv/spec.rs b/src/spv/spec.rs index 81ddb800..0c97285c 100644 --- a/src/spv/spec.rs +++ b/src/spv/spec.rs @@ -131,6 +131,7 @@ def_well_known! { // FIXME(eddyb) hide these from code, lowering should handle most cases. OpConstantComposite, + OpSpecConstantComposite, OpVariable, diff --git a/src/transform.rs b/src/transform.rs index 844751a0..1b5b0a52 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -429,7 +429,7 @@ impl InnerTransform for TypeDef { | TypeKind::QPtr | TypeKind::SpvStringLiteralForExtInst => Transformed::Unchanged, - TypeKind::SpvInst { spv_inst, type_and_const_inputs } => Transformed::map_iter( + TypeKind::SpvInst { spv_inst, type_and_const_inputs, value_lowering } => Transformed::map_iter( type_and_const_inputs.iter(), |ty_or_ct| match *ty_or_ct { TypeOrConst::Type(ty) => transform!({ @@ -443,6 +443,7 @@ impl InnerTransform for TypeDef { ).map(|new_iter| TypeKind::SpvInst { spv_inst: spv_inst.clone(), type_and_const_inputs: new_iter.collect(), + value_lowering: value_lowering.clone(), }), }, } => Self { @@ -532,10 +533,12 @@ impl InnerInPlaceTransform for GlobalVarDefBody { impl InnerInPlaceTransform for FuncDecl { fn inner_in_place_transform_with(&mut self, transformer: &mut impl Transformer) { - let Self { attrs, ret_type, params, def } = self; + let Self { attrs, ret_types, params, def } = self; transformer.transform_attr_set_use(*attrs).apply_to(attrs); - transformer.transform_type_use(*ret_type).apply_to(ret_type); + for ty in ret_types { + transformer.transform_type_use(*ty).apply_to(ty); + } for param in params { param.inner_transform_with(transformer).apply_to(param); } @@ -708,11 +711,15 @@ impl InnerInPlaceTransform for FuncAtMut<'_, DataInst> { impl InnerTransform for DataInstFormDef { fn inner_transform_with(&self, transformer: &mut impl Transformer) -> Transformed { - let Self { kind, output_type } = self; + let Self { kind, output_types } = 
self; transform!({ kind -> match kind { - DataInstKind::FuncCall(func) => transformer.transform_func_use(*func).map(DataInstKind::FuncCall), + DataInstKind::Scalar(_) + | DataInstKind::Vector(_) => Transformed::Unchanged, + DataInstKind::FuncCall(func) => transform!({ + func -> transformer.transform_func_use(*func) + } => DataInstKind::FuncCall(func)), DataInstKind::QPtr(op) => match op { QPtrOp::FuncLocalVar(_) | QPtrOp::HandleArrayIndex @@ -723,22 +730,39 @@ impl InnerTransform for DataInstFormDef { | QPtrOp::Load {..} | QPtrOp::Store {..} => Transformed::Unchanged, }, - DataInstKind::Scalar(_) - | DataInstKind::Vector(_) - | DataInstKind::SpvInst(_) - | DataInstKind::SpvExtInst { .. } => Transformed::Unchanged, + DataInstKind::SpvInst(spv_inst, lowering) => transform!({ + lowering -> lowering.inner_transform_with(transformer) + } => DataInstKind::SpvInst(spv_inst.clone(), lowering)), + DataInstKind::SpvExtInst { ext_set, inst, lowering } => transform!({ + lowering -> lowering.inner_transform_with(transformer) + } => DataInstKind::SpvExtInst { ext_set: *ext_set, inst: *inst, lowering }), }, - // FIXME(eddyb) this should be replaced with an impl of `InnerTransform` - // for `Option` or some other helper, to avoid "manual transpose". - output_type -> output_type.map(|ty| transformer.transform_type_use(ty)) - .map_or(Transformed::Unchanged, |t| t.map(Some)), + output_types -> Transformed::map_iter(output_types.iter(), |&ty| transformer.transform_type_use(ty)) + .map(|new_iter| new_iter.collect()), } => Self { kind, - output_type, + output_types, }) } } +impl InnerTransform for spv::InstLowering { + fn inner_transform_with(&self, transformer: &mut impl Transformer) -> Transformed { + let Self { disaggregated_output, disaggregated_inputs } = self; + + transform!({ + // FIXME(eddyb) this should be replaced with an impl of `InnerTransform` + // for `Option` or some other helper, to avoid "manual transpose". 
+ disaggregated_output -> disaggregated_output.map(|ty| transformer.transform_type_use(ty)) + .map_or(Transformed::Unchanged, |t| t.map(Some)), + disaggregated_inputs -> Transformed::map_iter( + disaggregated_inputs.iter(), + |(range, ty)| transformer.transform_type_use(*ty).map(|ty| (range.clone(), ty)) + ).map(|new_iter| new_iter.collect()), + } => Self { disaggregated_output, disaggregated_inputs }) + } +} + impl InnerInPlaceTransform for cfg::ControlInst { fn inner_in_place_transform_with(&mut self, transformer: &mut impl Transformer) { let Self { attrs, kind, inputs, targets: _, target_inputs } = self; @@ -773,7 +797,7 @@ impl InnerTransform for Value { Self::ControlRegionInput { region: _, input_idx: _ } | Self::ControlNodeOutput { control_node: _, output_idx: _ } - | Self::DataInstOutput(_) => Transformed::Unchanged, + | Self::DataInstOutput { inst: _, output_idx: _ } => Transformed::Unchanged, } } } diff --git a/src/visit.rs b/src/visit.rs index 4cbde0b4..d75b00c5 100644 --- a/src/visit.rs +++ b/src/visit.rs @@ -320,7 +320,7 @@ impl InnerVisit for TypeDef { | TypeKind::QPtr | TypeKind::SpvStringLiteralForExtInst => {} - TypeKind::SpvInst { spv_inst: _, type_and_const_inputs } => { + TypeKind::SpvInst { spv_inst: _, type_and_const_inputs, value_lowering: _ } => { for &ty_or_ct in type_and_const_inputs { match ty_or_ct { TypeOrConst::Type(ty) => visitor.visit_type_use(ty), @@ -396,10 +396,12 @@ impl InnerVisit for GlobalVarDefBody { impl InnerVisit for FuncDecl { fn inner_visit_with<'a>(&'a self, visitor: &mut impl Visitor<'a>) { - let Self { attrs, ret_type, params, def } = self; + let Self { attrs, ret_types, params, def } = self; visitor.visit_attr_set_use(*attrs); - visitor.visit_type_use(*ret_type); + for &ty in ret_types { + visitor.visit_type_use(ty); + } for param in params { param.inner_visit_with(visitor); } @@ -527,9 +529,10 @@ impl InnerVisit for DataInstDef { impl InnerVisit for DataInstFormDef { fn inner_visit_with<'a>(&'a self, visitor: &mut 
impl Visitor<'a>) { - let Self { kind, output_type } = self; + let Self { kind, output_types } = self; match kind { + DataInstKind::Scalar(_) | DataInstKind::Vector(_) => {} &DataInstKind::FuncCall(func) => visitor.visit_func_use(func), DataInstKind::QPtr(op) => match *op { QPtrOp::FuncLocalVar(_) @@ -541,12 +544,25 @@ impl InnerVisit for DataInstFormDef { | QPtrOp::Load { .. } | QPtrOp::Store { .. } => {} }, - DataInstKind::Scalar(_) - | DataInstKind::Vector(_) - | DataInstKind::SpvInst(_) - | DataInstKind::SpvExtInst { .. } => {} + DataInstKind::SpvInst(_, lowering) + | DataInstKind::SpvExtInst { ext_set: _, inst: _, lowering } => { + lowering.inner_visit_with(visitor); + } + } + for &ty in output_types { + visitor.visit_type_use(ty); + } + } +} + +impl InnerVisit for spv::InstLowering { + fn inner_visit_with<'a>(&'a self, visitor: &mut impl Visitor<'a>) { + let Self { disaggregated_output, disaggregated_inputs } = self; + + if let Some(ty) = *disaggregated_output { + visitor.visit_type_use(ty); } - if let Some(ty) = *output_type { + for &(_, ty) in disaggregated_inputs { visitor.visit_type_use(ty); } } @@ -583,7 +599,7 @@ impl InnerVisit for Value { Self::Const(ct) => visitor.visit_const_use(ct), Self::ControlRegionInput { region: _, input_idx: _ } | Self::ControlNodeOutput { control_node: _, output_idx: _ } - | Self::DataInstOutput(_) => {} + | Self::DataInstOutput { inst: _, output_idx: _ } => {} } } } From fcc57bdb387ed0e33e6d1177ea65fe2033d3abb3 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Sun, 1 Oct 2023 17:10:42 +0300 Subject: [PATCH 15/22] qptr/lower: expand aggregate `Op{Load,Store}` into leaf loads/stores. 
--- src/qptr/layout.rs | 74 +++++++++ src/qptr/lower.rs | 375 +++++++++++++++++++++++++++++++++++++-------- src/spv/spec.rs | 1 + 3 files changed, 384 insertions(+), 66 deletions(-) diff --git a/src/qptr/layout.rs b/src/qptr/layout.rs index 8d550428..559c9bae 100644 --- a/src/qptr/layout.rs +++ b/src/qptr/layout.rs @@ -99,6 +99,80 @@ pub(super) enum Components { }, } +impl MemTypeLayout { + /// Recursively expand `MemTypeLayout`s into their components, at every level + /// for which `recurse_into` returns `true`. `each_leaf` is called for each + /// leaf (scalar or `recurse_into` returned `false`) component, and includes + /// its offset (starting at `base_offset`). + /// + /// Because each array element has its own offset, each array element will + /// be separately flattened, such that the entire array will be covered. + /// + /// `Err` may be returned in some cases (e.g. offset overflows, dynamic arrays), + /// in which case the sequence of leaves `each_leaf` produced can be considered + /// incomplete and shouldn't be used. + pub(super) fn deeply_flatten_if( + &self, + base_offset: i32, + recurse_into: &impl Fn(&Self) -> bool, + each_leaf: &mut impl FnMut(i32, &Self) -> Result<(), LayoutError>, + ) -> Result<(), LayoutError> { + match &self.components { + Components::Scalar => each_leaf(base_offset, self), + _ if !recurse_into(self) => each_leaf(base_offset, self), + + Components::Elements { stride, elem, fixed_len } => { + let len = fixed_len.ok_or_else(|| { + LayoutError(Diag::err([ + "dynamically sized type `".into(), + self.original_type.into(), + "` cannot be flattened into a finite sequence of leaves".into(), + ])) + })?; + + for i in 0..len.get() { + let offset = i32::try_from(i) + .ok() + .and_then(|i| { + // HACK(eddyb) don't claim an overflow for `0 * stride` + // even if `stride` doesn't fit in `i32`. + if i == 0 { + Some(base_offset) + } else { + let stride = i32::try_from(stride.get()).ok()?; + base_offset.checked_add(i.checked_mul(stride)?)
+ } + }) + .ok_or_else(|| { + LayoutError(Diag::bug([format!( + "`{base_offset} + {i} * {stride}` overflowed `s32`" + ) + .into()])) + })?; + elem.deeply_flatten_if(offset, recurse_into, each_leaf)?; + } + Ok(()) + } + + Components::Fields { offsets, layouts } => { + for (&field_offset, field) in offsets.iter().zip(layouts) { + let offset = i32::try_from(field_offset) + .ok() + .and_then(|field_offset| base_offset.checked_add(field_offset)) + .ok_or_else(|| { + LayoutError(Diag::bug([format!( + "`{base_offset} + {field_offset}` overflowed `s32`" + ) + .into()])) + })?; + field.deeply_flatten_if(offset, recurse_into, each_leaf)?; + } + Ok(()) + } + } + } +} + impl Components { /// Return all components (by index), which completely contain `offset_range`. /// diff --git a/src/qptr/lower.rs b/src/qptr/lower.rs index 23c39559..c0bba605 100644 --- a/src/qptr/lower.rs +++ b/src/qptr/lower.rs @@ -12,6 +12,7 @@ use crate::{ EntityOrientedDenseMap, FuncDecl, GlobalVarDecl, OrdAssertEq, Type, TypeKind, TypeOrConst, Value, }; +use itertools::Either; use rustc_hash::FxHashMap; use smallvec::SmallVec; use std::cell::Cell; @@ -131,6 +132,10 @@ impl<'a> LowerFromSpvPtrs<'a> { // HACK(eddyb) this should handle shallow `QPtr` in the initializer, but // typed initializers should be replaced with miri/linker-style ones. + // FIXME(eddyb) this is even worse now, with disaggregation, + // the initializer should be disaggregated leaves, which then + // need to be flattened into a miri-like representation, or at least + // have offsets assigned to each leaf (for `qptr::lift` to use).
EraseSpvPtrs { lowerer: self }.in_place_transform_global_var_decl(global_var_decl); } Err(LowerError(e)) => { @@ -149,6 +154,7 @@ impl<'a> LowerFromSpvPtrs<'a> { data_inst_use_counts: Default::default(), remove_if_dead_inst_and_parent_block: Default::default(), noop_offsets_to_base_ptr: Default::default(), + aggregate_load_to_leaf_loads: Default::default(), } .in_place_transform_func_decl(func_decl); EraseSpvPtrs { lowerer: self }.in_place_transform_func_decl(func_decl); @@ -263,6 +269,9 @@ struct LowerFromSpvPtrInstsInFunc<'a> { // because it needs to be available from `transform_value`, which doesn't // have access to a `FuncAt` to look up anything. noop_offsets_to_base_ptr: FxHashMap, + + // HACK(eddyb) perhaps this should be a generalized "replace all uses of"? + aggregate_load_to_leaf_loads: FxHashMap>, } /// One `QPtr`->`QPtr` step used in the lowering of `Op*AccessChain`. @@ -456,15 +465,30 @@ impl LowerFromSpvPtrInstsInFunc<'_> { (ptr, offset) }; - let replacement_kind_and_inputs = if spv_inst.opcode == wk.OpVariable { - assert!(!disaggregated_output_or_inputs_during_lowering); - assert_eq!(output_types.len(), 1); - assert!(data_inst_def.inputs.len() <= 1); + // NOTE(eddyb) the ordering of some checks below is not purely aesthetic, + // if the types are invalid there could e.g. be disaggregation where it + // should never otherwise appear, so type checks should precede them. + let replacement_kind_and_inputs = if spv_inst.opcode == wk.OpVariable { + // HACK(eddyb) only needed because of potentially invalid SPIR-V. 
+ let output_type = + spv_inst_lowering.disaggregated_output.unwrap_or_else(|| output_types[0]); let (_, var_data_type) = - self.lowerer.as_spv_ptr_type(output_types[0]).ok_or_else(|| { + self.lowerer.as_spv_ptr_type(output_type).ok_or_else(|| { LowerError(Diag::bug(["output type not an `OpTypePointer`".into()])) })?; + + assert!(spv_inst_lowering.disaggregated_output.is_none()); + + // FIXME(eddyb) this can happen due to the optional initializer. + // FIXME(eddyb) lower the initializer to store(s) just after variables. + if !spv_inst_lowering.disaggregated_inputs.is_empty() { + return Ok(Transformed::Unchanged); + } + + assert_eq!(output_types.len(), 1); + assert!(data_inst_def.inputs.len() <= 1); + match self.lowerer.layout_of(var_data_type)? { TypeLayout::Concrete(concrete) if concrete.mem_layout.dyn_unit_stride.is_none() => { ( @@ -474,41 +498,15 @@ impl LowerFromSpvPtrInstsInFunc<'_> { } _ => return Ok(Transformed::Unchanged), } - } else if spv_inst.opcode == wk.OpLoad { - // FIXME(eddyb) expand into per-leaf accesses. + } else if spv_inst.opcode == wk.OpArrayLength { if disaggregated_output_or_inputs_during_lowering { - return Ok(Transformed::Unchanged); + return Err(LowerError(Diag::bug([format!( + "unexpected aggregate types in `{}`", + spv_inst.opcode.name() + ) + .into()]))); } - // FIXME(eddyb) support memory operands somehow. - if !spv_inst.imms.is_empty() { - return Ok(Transformed::Unchanged); - } - assert_eq!(data_inst_def.inputs.len(), 1); - - let ptr = data_inst_def.inputs[0]; - - let (ptr, offset) = flatten_offsets(ptr); - (QPtrOp::Load { offset }.into(), [ptr].into_iter().collect()) - } else if spv_inst.opcode == wk.OpStore { - // FIXME(eddyb) expand into per-leaf accesses. - if disaggregated_output_or_inputs_during_lowering { - return Ok(Transformed::Unchanged); - } - // FIXME(eddyb) support memory operands somehow.
- if !spv_inst.imms.is_empty() { - return Ok(Transformed::Unchanged); - } - assert_eq!(data_inst_def.inputs.len(), 2); - - let ptr = data_inst_def.inputs[0]; - let value = data_inst_def.inputs[1]; - - let (ptr, offset) = flatten_offsets(ptr); - - (QPtrOp::Store { offset }.into(), [ptr, value].into_iter().collect()) - } else if spv_inst.opcode == wk.OpArrayLength { - assert!(!disaggregated_output_or_inputs_during_lowering); let field_idx = match spv_inst.imms[..] { [spv::Imm::Short(_, field_idx)] => field_idx, _ => unreachable!(), @@ -578,7 +576,13 @@ impl LowerFromSpvPtrInstsInFunc<'_> { ] .contains(&spv_inst.opcode) { - assert!(!disaggregated_output_or_inputs_during_lowering); + if disaggregated_output_or_inputs_during_lowering { + return Err(LowerError(Diag::bug([format!( + "unexpected aggregate types in `{}`", + spv_inst.opcode.name() + ) + .into()]))); + } // FIXME(eddyb) avoid erasing the "inbounds" qualifier. let base_ptr = data_inst_def.inputs[0]; @@ -626,7 +630,9 @@ impl LowerFromSpvPtrInstsInFunc<'_> { let step_data_inst = func_at_data_inst.reborrow().data_insts.define( cx, DataInstDef { - attrs: Default::default(), + // FIXME(eddyb) filter attributes into debuginfo and + // semantic, and understand the semantic ones. + attrs, form: cx.intern(DataInstFormDef { kind, output_types: [self.lowerer.qptr_type()].into_iter().collect(), @@ -658,8 +664,251 @@ impl LowerFromSpvPtrInstsInFunc<'_> { ptr = Value::DataInstOutput { inst: step_data_inst, output_idx: 0 }; } final_step.into_data_inst_kind_and_inputs(ptr) + } else if [wk.OpLoad, wk.OpStore].contains(&spv_inst.opcode) { + let ptr = data_inst_def.inputs[0]; + + // HACK(eddyb) only needed because of potentially invalid SPIR-V. + let type_of_ptr = match &spv_inst_lowering.disaggregated_inputs[..] { + [(range, _), ..] 
if range.start == 0 => None, + _ => Some(func.at(ptr).type_of(cx)), + }; + let (_, pointee_type) = type_of_ptr + .and_then(|type_of_ptr| self.lowerer.as_spv_ptr_type(type_of_ptr)) + .ok_or_else(|| { + LowerError(Diag::bug(["pointer input not an `OpTypePointer`".into()])) + })?; + + #[derive(Copy, Clone)] + enum Access { + Load(Type), + Store(Value), + } + + impl Access { + fn to_data_inst_form_and_extra_input( + self, + cx: &Context, + offset: i32, + ) -> (DataInstForm, Option) { + match self { + Access::Load(ty) => ( + cx.intern(DataInstFormDef { + kind: QPtrOp::Load { offset }.into(), + output_types: [ty].into_iter().collect(), + }), + None, + ), + Access::Store(value) => ( + cx.intern(DataInstFormDef { + kind: QPtrOp::Store { offset }.into(), + output_types: [].into_iter().collect(), + }), + Some(value), + ), + } + } + } + + enum Accesses> { + Single(Access), + AggregateLeaves { aggregate_type: Type, leaf_accesses: LLA }, + } + + let accesses = if spv_inst.opcode == wk.OpLoad { + assert!(spv_inst_lowering.disaggregated_inputs.is_empty()); + assert_eq!(data_inst_def.inputs.len(), 1); + + match spv_inst_lowering.disaggregated_output { + None => Accesses::Single(Access::Load(output_types[0])), + Some(aggregate_type) => Accesses::AggregateLeaves { + aggregate_type, + leaf_accesses: Either::Left( + output_types.iter().map(|&ty| Access::Load(ty)), + ), + }, + } + } else { + assert!(spv_inst_lowering.disaggregated_output.is_none()); + + match spv_inst_lowering.disaggregated_inputs[..] 
{ + [] => { + assert_eq!(data_inst_def.inputs.len(), 2); + + Accesses::Single(Access::Store(data_inst_def.inputs[1])) + } + [(ref range, aggregate_type)] => { + assert_eq!(*range, 1..u32::try_from(data_inst_def.inputs.len()).unwrap()); + + Accesses::AggregateLeaves { + aggregate_type, + leaf_accesses: Either::Right( + data_inst_def.inputs[1..].iter().map(|&v| Access::Store(v)), + ), + } + } + _ => unreachable!(), + } + }; + + let type_of_access = |access| match access { + Access::Load(ty) => ty, + Access::Store(value) => func.at(value).type_of(cx), + }; + + let original_access_type = match accesses { + Accesses::Single(access) => type_of_access(access), + Accesses::AggregateLeaves { aggregate_type, .. } => aggregate_type, + }; + + if pointee_type != original_access_type { + return Err(LowerError(Diag::bug([ + "access type different from pointee type".into() + ]))); + } + + let (ptr, base_offset) = flatten_offsets(ptr); + + // FIXME(eddyb) support memory operands somehow. + if !spv_inst.imms.is_empty() { + return Ok(Transformed::Unchanged); + } + + match accesses { + Accesses::Single(access) => { + let (form, extra_input) = + access.to_data_inst_form_and_extra_input(cx, base_offset); + return Ok(Transformed::Changed(DataInstDef { + attrs, + form, + inputs: [ptr].into_iter().chain(extra_input).collect(), + })); + } + + // If this is an aggregate `OpLoad`/`OpStore`, we should generate + // one instruction per leaf, instead. + Accesses::AggregateLeaves { aggregate_type: _, mut leaf_accesses } => { + let mem_data_layout = match self.lowerer.layout_of(pointee_type)? { + TypeLayout::Concrete(mem) => mem, + _ => { + return Err(LowerError(Diag::bug([ + "by-value aggregate type without memory layout: ".into(), + pointee_type.into(), + ]))); + } + }; + + // HACK(eddyb) we have to buffer the details of the new + // instructions because we're iterating over the original + // one, and can't allocate the new `DataInst`s as we go. 
+ let mut leaf_forms_and_extra_inputs = SmallVec::<[_; 4]>::new(); + mem_data_layout + .deeply_flatten_if( + base_offset, + // Whether `candidate_layout` is an aggregate (to recurse into). + &|candidate_layout| matches!( + &cx[candidate_layout.original_type].kind, + TypeKind::SpvInst { value_lowering: spv::ValueLowering::Disaggregate(_), .. } + ), + &mut |leaf_offset, leaf| { + let leaf_access = leaf_accesses.next().ok_or_else(|| { + LayoutError(Diag::bug([ + "`spv::lower` and `qptr::layout` disagree on aggregate leaves of ".into(), + pointee_type.into(), + ])) + })?; + let leaf_type = type_of_access(leaf_access); + if leaf_type != leaf.original_type { + return Err(LayoutError(Diag::bug([ + "aggregate leaf mismatch: `".into(), + leaf_type.into(), + "` vs `".into(), + leaf.original_type.into(), + "`".into() + ]))); + } + leaf_forms_and_extra_inputs.push(leaf_access.to_data_inst_form_and_extra_input(cx, leaf_offset)); + Ok(()) + }, + ) + .map_err(|LayoutError(err)| LowerError(err))?; + + if leaf_accesses.next().is_some() { + return Err(LowerError(Diag::bug([ + "`spv::lower` and `qptr::layout` disagree on aggregate leaves of " + .into(), + pointee_type.into(), + ]))); + } + + // HACK(eddyb) this is for `aggregate_load_to_leaf_loads`, + // which gets used later, to replace uses of one of the + // outputs of the original `OpLoad`, with uses of leaf loads. + let mut leaf_loads = if spv_inst.opcode == wk.OpLoad { + Some(SmallVec::with_capacity(leaf_forms_and_extra_inputs.len())) + } else { + None + }; + + // This is the point of no return: we're inserting several + // new instructions, and marking the old one for removal. + for (form, extra_input) in leaf_forms_and_extra_inputs { + let leaf_data_inst = func_at_data_inst.reborrow().data_insts.define( + cx, + DataInstDef { + // FIXME(eddyb) filter attributes into debuginfo and + // semantic, and understand the semantic ones.
+ attrs, + form, + inputs: [ptr].into_iter().chain(extra_input).collect(), + } + .into(), + ); + + // HACK(eddyb) can't really use helpers like `FuncAtMut::def`, + // due to the need to borrow `control_nodes` and `data_insts` + // at the same time - perhaps some kind of `FuncAtMut` position + // types for "where a list is in a parent entity" could be used + // to make this more ergonomic, although the potential need for + // an actual list entity of its own, should be considered. + let func = func_at_data_inst.reborrow().at(()); + match &mut func.control_nodes[parent_block].kind { + ControlNodeKind::Block { insts } => { + insts.insert_before(leaf_data_inst, data_inst, func.data_insts); + } + _ => unreachable!(), + } + + if let Some(leaf_loads) = &mut leaf_loads { + leaf_loads.push(leaf_data_inst); + } + } + self.remove_if_dead_inst_and_parent_block.push((data_inst, parent_block)); + if let Some(leaf_loads) = leaf_loads { + self.aggregate_load_to_leaf_loads.insert(data_inst, leaf_loads); + } + + // HACK(eddyb) this is a bit counter-intuitive (and wasteful), + // but we expect the original instruction to be removed as + // effectively unused, so this will only be kept *if that fails*. + return Err(LowerError(Diag::bug([ + "disaggregation of `OpLoad`/`OpStore` should've \ + removed the original instruction, but failed to" + .into(), + ]))); + } + } + } else if spv_inst.opcode == wk.OpCopyMemory { + // FIXME(eddyb) partially disaggregate (`OpTypeStruct` but not `OpTypeArray`). 
+ return Ok(Transformed::Unchanged); } else if spv_inst.opcode == wk.OpBitcast { - assert!(!disaggregated_output_or_inputs_during_lowering); + if disaggregated_output_or_inputs_during_lowering { + return Err(LowerError(Diag::bug([format!( + "unexpected aggregate types in `{}`", + spv_inst.opcode.name() + ) + .into()]))); + } + assert_eq!(output_types.len(), 1); assert_eq!(data_inst_def.inputs.len(), 1); @@ -780,28 +1029,33 @@ impl LowerFromSpvPtrInstsInFunc<'_> { } } } + + // HACK(eddyb) this is a helper *only* for `transform_value_use` and + // `in_place_transform_control_node_def`, and should not be used elsewhere. + fn apply_value_replacements(&self, mut value: Value) -> Value { + while let Value::DataInstOutput { inst, output_idx } = value { + value = if let Some(&base_ptr) = self.noop_offsets_to_base_ptr.get(&inst) { + assert_eq!(output_idx, 0); + base_ptr + } else if let Some(leaf_loads) = self.aggregate_load_to_leaf_loads.get(&inst) { + Value::DataInstOutput { inst: leaf_loads[output_idx as usize], output_idx: 0 } + } else { + break; + }; + } + value + } } impl Transformer for LowerFromSpvPtrInstsInFunc<'_> { // NOTE(eddyb) it's important that this only gets invoked on already lowered // `Value`s, so we can rely on e.g. `noop_offsets_to_base_ptr` being filled. 
fn transform_value_use(&mut self, v: &Value) -> Transformed { - let mut v = *v; + let new_v = self.apply_value_replacements(*v); - let transformed = match v { - Value::DataInstOutput { inst, output_idx: 0 } => self - .noop_offsets_to_base_ptr - .get(&inst) - .copied() - .map_or(Transformed::Unchanged, Transformed::Changed), + self.add_value_uses(&[new_v]); - _ => Transformed::Unchanged, - }; - - transformed.apply_to(&mut v); - self.add_value_uses(&[v]); - - transformed + if *v == new_v { Transformed::Unchanged } else { Transformed::Changed(new_v) } } // HACK(eddyb) while we want to transform `DataInstDef`s, we can't inject @@ -841,18 +1095,7 @@ impl Transformer for LowerFromSpvPtrInstsInFunc<'_> { } if let QPtrOp::Offset(0) = op { - let mut base_ptr = new_def.inputs[0]; - if let Value::DataInstOutput { - inst: base_ptr_inst, - output_idx: 0, - } = base_ptr - { - if let Some(&base_ptr_base_ptr) = - self.noop_offsets_to_base_ptr.get(&base_ptr_inst) - { - base_ptr = base_ptr_base_ptr; - } - } + let base_ptr = self.apply_value_replacements(new_def.inputs[0]); self.noop_offsets_to_base_ptr .insert(func_at_inst.position, base_ptr); } diff --git a/src/spv/spec.rs b/src/spv/spec.rs index 0c97285c..d5dc8f9b 100644 --- a/src/spv/spec.rs +++ b/src/spv/spec.rs @@ -155,6 +155,7 @@ def_well_known! { OpLoad, OpStore, + OpCopyMemory, OpArrayLength, OpAccessChain, OpInBoundsAccessChain, From d4a944368a8af6357228ed9443b34921df4af036 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Sun, 1 Oct 2023 17:12:05 +0300 Subject: [PATCH 16/22] qptr/lower: expand `OpCopyMemory` into leaf loads+stores (in the absence of large arrays). 
--- src/qptr/lower.rs | 195 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 193 insertions(+), 2 deletions(-) diff --git a/src/qptr/lower.rs b/src/qptr/lower.rs index c0bba605..7515b03b 100644 --- a/src/qptr/lower.rs +++ b/src/qptr/lower.rs @@ -787,6 +787,8 @@ impl LowerFromSpvPtrInstsInFunc<'_> { // If this is an aggregate `OpLoad`/`OpStore`, we should generate // one instruction per leaf, instead. Accesses::AggregateLeaves { aggregate_type: _, mut leaf_accesses } => { + // FIXME(eddyb) this may need to automatically generate an + // intermediary `QPtrOp::BufferData` when accessing buffers. let mem_data_layout = match self.lowerer.layout_of(pointee_type)? { TypeLayout::Concrete(mem) => mem, _ => { @@ -898,8 +900,197 @@ impl LowerFromSpvPtrInstsInFunc<'_> { } } } else if spv_inst.opcode == wk.OpCopyMemory { - // FIXME(eddyb) partially disaggregate (`OpTypeStruct` but not `OpTypeArray`). - return Ok(Transformed::Unchanged); + if disaggregated_output_or_inputs_during_lowering { + return Err(LowerError(Diag::bug([format!( + "unexpected aggregate types in `{}`", + spv_inst.opcode.name() + ) + .into()]))); + } + + assert_eq!(data_inst_def.inputs.len(), 2); + + let dst_ptr = data_inst_def.inputs[0]; + let src_ptr = data_inst_def.inputs[1]; + + let (_, dst_pointee_type) = + self.lowerer.as_spv_ptr_type(func.at(dst_ptr).type_of(cx)).ok_or_else(|| { + LowerError(Diag::bug([ + "destination pointer input not an `OpTypePointer`".into() + ])) + })?; + let (_, src_pointee_type) = + self.lowerer.as_spv_ptr_type(func.at(src_ptr).type_of(cx)).ok_or_else(|| { + LowerError(Diag::bug(["source pointer input not an `OpTypePointer`".into()])) + })?; + + if dst_pointee_type != src_pointee_type { + return Err(LowerError(Diag::bug([ + "copy destination pointee type different from source pointee type".into(), + ]))); + } + + // FIXME(eddyb) this may need to automatically generate an + // intermediary `QPtrOp::BufferData` when accessing buffers. 
+ let mem_data_layout = match self.lowerer.layout_of(src_pointee_type)? { + TypeLayout::Concrete(mem) => mem, + _ => { + return Err(LowerError(Diag::bug([ + "`OpCopyMemory` of data with non-memory type: ".into(), + src_pointee_type.into(), + ]))); + } + }; + + let (dst_ptr, dst_base_offset) = flatten_offsets(dst_ptr); + let (src_ptr, src_base_offset) = flatten_offsets(src_ptr); + + // FIXME(eddyb) support memory operands somehow. + if !spv_inst.imms.is_empty() { + return Ok(Transformed::Unchanged); + } + + // HACK(eddyb) this is speculative, so we just give up if we hit + // some situation we don't currently support - ideally, there would + // be an *untyped* `qptr.copy`, but that is harder to support overall. + // HACK(eddyb) this is a `try {...}`-like use of a closure. + let try_gather_leaf_offsets_and_types = || { + struct UnsupportedLargeArray; + let recurse_into_layout = |layout: &MemTypeLayout| { + let aggregate_shape = match &cx[layout.original_type].kind { + TypeKind::SpvInst { + value_lowering: spv::ValueLowering::Disaggregate(aggregate_shape), + .. + } => aggregate_shape, + _ => return Ok(false), + }; + match *aggregate_shape { + spv::AggregateShape::Struct { .. } => Ok(true), + + // HACK(eddyb) 16 leaves allows for a 4x4 matrix, even + // when represented as e.g. `[f32; 16]` or `[[f32; 4]; 4]` + // (this comparison gets more complex when accounting + // for vectors, e.g. `[f32x4; 4]`, which is only 4 leaves), + // but ideally most types accepted here will be even + // smaller arrays (which could've e.g. been structs). + // FIXME(eddyb) larger arrays should lower to loops that + // copy a small number of leaves per iteration, or even + // some general-purpose `qptr.copy`, to avoid generating + // amounts of IR that scale with the array length, which + // (unlike struct fields) can be arbitrarily large. + spv::AggregateShape::Array { total_leaf_count, .. 
} => { + if total_leaf_count <= 16 { + Ok(true) + } else { + Err(UnsupportedLargeArray) + } + } + } + }; + + // HACK(eddyb) buffering the details of the instructions we'll + // be generating, because we don't know ahead of time whether we + // even want to expand the `OpCopyMemory`, at all. + let mut leaf_offsets_and_types = SmallVec::<[_; 8]>::new(); + mem_data_layout + .deeply_flatten_if( + 0, + &|candidate_layout| recurse_into_layout(candidate_layout).unwrap_or(false), + &mut |leaf_offset, leaf| { + // FIXME(eddyb) ideally this would not be computed twice. + recurse_into_layout(leaf).map_err(|UnsupportedLargeArray| { + // HACK(eddyb) not an error, just stopping traversal. + LayoutError(Diag::bug([])) + })?; + + // HACK(eddyb) `deeply_flatten_if` takes a base offset, + // but we have two, so we need our own overflow checks. + if dst_base_offset.checked_add(leaf_offset).is_none() + || src_base_offset.checked_add(leaf_offset).is_none() + { + // HACK(eddyb) not an error, just stopping traversal. + return Err(LayoutError(Diag::bug([]))); + } + + leaf_offsets_and_types.push((leaf_offset, leaf.original_type)); + + Ok(()) + }, + ) + .ok()?; + Some(leaf_offsets_and_types) + }; + let leaf_offsets_and_types = match try_gather_leaf_offsets_and_types() { + Some(leaf_offsets_and_types) => leaf_offsets_and_types, + None => return Ok(Transformed::Unchanged), + }; + + // This is the point of no return: we're inserting several + // new instructions, and marking the old one for removal. + for (leaf_offset, leaf_type) in leaf_offsets_and_types { + let leaf_load_data_inst = func_at_data_inst.reborrow().data_insts.define( + cx, + DataInstDef { + // FIXME(eddyb) filter attributes into debuginfo and + // semantic, and understand the semantic ones.
+ attrs, + form: cx.intern(DataInstFormDef { + kind: QPtrOp::Load { + offset: src_base_offset.checked_add(leaf_offset).unwrap(), + } + .into(), + output_types: [leaf_type].into_iter().collect(), + }), + inputs: [src_ptr].into_iter().collect(), + } + .into(), + ); + let leaf_store_data_inst = func_at_data_inst.reborrow().data_insts.define( + cx, + DataInstDef { + // FIXME(eddyb) filter attributes into debuginfo and + // semantic, and understand the semantic ones. + attrs, + form: cx.intern(DataInstFormDef { + kind: QPtrOp::Store { + offset: dst_base_offset.checked_add(leaf_offset).unwrap(), + } + .into(), + output_types: [].into_iter().collect(), + }), + inputs: [ + dst_ptr, + Value::DataInstOutput { inst: leaf_load_data_inst, output_idx: 0 }, + ] + .into_iter() + .collect(), + } + .into(), + ); + + // HACK(eddyb) can't really use helpers like `FuncAtMut::def`, + // due to the need to borrow `control_nodes` and `data_insts` + // at the same time - perhaps some kind of `FuncAtMut` position + // types for "where a list is in a parent entity" could be used + // to make this more ergonomic, although the potential need for + // an actual list entity of its own, should be considered. + let func = func_at_data_inst.reborrow().at(()); + match &mut func.control_nodes[parent_block].kind { + ControlNodeKind::Block { insts } => { + insts.insert_before(leaf_load_data_inst, data_inst, func.data_insts); + insts.insert_before(leaf_store_data_inst, data_inst, func.data_insts); + } + _ => unreachable!(), + } + } + self.remove_if_dead_inst_and_parent_block.push((data_inst, parent_block)); + + // HACK(eddyb) this is a bit counter-intuitive (and wasteful), + // but we expect the original instruction to be removed as + // effectively unused, so this will only be kept *if that fails*. 
+ return Err(LowerError(Diag::bug(["disaggregation of `OpCopyMemory` should've \ + removed the original instruction, but failed to" + .into()]))); } else if spv_inst.opcode == wk.OpBitcast { if disaggregated_output_or_inputs_during_lowering { return Err(LowerError(Diag::bug([format!( From 830b4e8650ba2b2a90b755d85664202fc51c78ca Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Thu, 5 Oct 2023 13:26:51 +0300 Subject: [PATCH 17/22] print: attempt to improve `GlobalVar` printing by using named arguments. --- README.md | 2 +- src/print/mod.rs | 257 +++++++++++++++++++++++++---------------------- 2 files changed, 138 insertions(+), 121 deletions(-) diff --git a/README.md b/README.md index 582da461..30e848a1 100644 --- a/README.md +++ b/README.md @@ -135,7 +135,7 @@ fn main() -> @location(0) i32 { ```cxx #[spv.Decoration.Flat] #[spv.Decoration.Location(Location: 0)] -global_var GV0 in spv.StorageClass.Output: s32 +global_var GV0(spv.StorageClass.Output): s32 func F0() { loop(v0: s32 <- 1s32, v1: s32 <- 1s32) { diff --git a/src/print/mod.rs b/src/print/mod.rs index 553f21ef..c9e22496 100644 --- a/src/print/mod.rs +++ b/src/print/mod.rs @@ -2566,139 +2566,148 @@ impl Print for GlobalVarDecl { let wk = &spv::spec::Spec::get().well_known; - // HACK(eddyb) get the pointee type from SPIR-V `OpTypePointer`, but - // ideally the `GlobalVarDecl` would hold that type itself. 
- let type_ascription_suffix = match &printer.cx[*type_of_ptr_to].kind { - TypeKind::QPtr if shape.is_some() => match shape.unwrap() { - qptr::shapes::GlobalVarShape::Handles { handle, fixed_count } => { - let handle = match handle { - qptr::shapes::Handle::Opaque(ty) => ty.print(printer), - qptr::shapes::Handle::Buffer(addr_space, buf) => pretty::Fragment::new([ - printer.declarative_keyword_style().apply("buffer").into(), - pretty::join_comma_sep( - "(", - [ - addr_space.print(printer), - pretty::Fragment::new([ - printer.pretty_named_argument_prefix("size"), - pretty::Fragment::new( - Some(buf.fixed_base.size) - .filter(|&base_size| { - base_size > 0 || buf.dyn_unit_stride.is_none() - }) - .map(|base_size| { - printer - .numeric_literal_style() - .apply(base_size.to_string()) - .into() - }) - .into_iter() - .chain(buf.dyn_unit_stride.map(|stride| { - pretty::Fragment::new([ - "N × ".into(), - printer - .numeric_literal_style() - .apply(stride.to_string()), - ]) - })) - .intersperse_with(|| " + ".into()), - ), - ]), - pretty::Fragment::new([ - printer.pretty_named_argument_prefix("align"), - printer - .numeric_literal_style() - .apply(buf.fixed_base.align.to_string()) - .into(), - ]), - ], - ")", - ), - ]), - }; + // HACK(eddyb) to avoid too many syntax variations, most details (other + // than the type, if present) use named arguments in `GV123(...)`. 
+ let mut details = SmallVec::<[_; 4]>::new(); - let handles = if fixed_count.map_or(0, |c| c.get()) == 1 { - handle - } else { - pretty::Fragment::new([ - "[".into(), - fixed_count - .map(|count| { - pretty::Fragment::new([ - printer.numeric_literal_style().apply(count.to_string()), - " × ".into(), - ]) - }) - .unwrap_or_default(), - handle, - "]".into(), - ]) - }; - pretty::join_space(":", [handles]) - } - qptr::shapes::GlobalVarShape::UntypedData(mem_layout) => pretty::Fragment::new([ - " ".into(), - printer.declarative_keyword_style().apply("layout").into(), - pretty::join_comma_sep( - "(", - [ - pretty::Fragment::new([ - printer.pretty_named_argument_prefix("size"), - printer - .numeric_literal_style() - .apply(mem_layout.size.to_string()) - .into(), - ]), - pretty::Fragment::new([ - printer.pretty_named_argument_prefix("align"), - printer - .numeric_literal_style() - .apply(mem_layout.align.to_string()) - .into(), - ]), - ], - ")", - ), - ]), - qptr::shapes::GlobalVarShape::TypedInterface(ty) => { - printer.pretty_type_ascription_suffix(ty) - } - }, + match addr_space { + AddrSpace::Handles => {} + AddrSpace::SpvStorageClass(_) => { + details.push(addr_space.print(printer)); + } + } + + // FIXME(eddyb) should this be a helper on `Printer`? + let num_lit = |x: u32| printer.numeric_literal_style().apply(format!("{x}")).into(); + + // FIXME(eddyb) should the pointer type be shown as something like + // `&GV123: OpTypePointer(..., T123)` *after* the variable definition? + // (but each reference can technically have a different type...) + let (qptr_shape, spv_ptr_pointee_type) = match &printer.cx[*type_of_ptr_to].kind { + TypeKind::QPtr => (shape.as_ref(), None), + + // HACK(eddyb) get the pointee type from SPIR-V `OpTypePointer`, but + // ideally the `GlobalVarDecl` would hold that type itself. TypeKind::SpvInst { spv_inst, type_and_const_inputs, .. } if spv_inst.opcode == wk.OpTypePointer => { match type_and_const_inputs[..] 
{ - [TypeOrConst::Type(ty)] => printer.pretty_type_ascription_suffix(ty), - _ => unreachable!(), + [TypeOrConst::Type(pointee_type)] => (None, Some(pointee_type)), + _ => (None, None), } } - _ => pretty::Fragment::new([ - ": ".into(), - printer.error_style().apply("pointee_type_of").into(), - "(".into(), - type_of_ptr_to.print(printer), - ")".into(), - ]), + + _ => (None, None), }; - let addr_space_suffix = match addr_space { - AddrSpace::Handles => pretty::Fragment::default(), - AddrSpace::SpvStorageClass(_) => { - pretty::Fragment::new([" in ".into(), addr_space.print(printer)]) + let ascribe_type = match qptr_shape { + Some(qptr::shapes::GlobalVarShape::Handles { handle, fixed_count }) => { + let handle = match handle { + qptr::shapes::Handle::Opaque(ty) => ty.print(printer), + qptr::shapes::Handle::Buffer(addr_space, buf) => pretty::Fragment::new([ + printer.declarative_keyword_style().apply("buffer").into(), + pretty::join_comma_sep( + "(", + [ + addr_space.print(printer), + pretty::Fragment::new([ + printer.pretty_named_argument_prefix("size"), + pretty::Fragment::new( + [ + Some(buf.fixed_base.size) + .filter(|&base_size| { + base_size > 0 || buf.dyn_unit_stride.is_none() + }) + .map(num_lit), + buf.dyn_unit_stride.map(|stride| { + pretty::Fragment::new([ + "N × ".into(), + num_lit(stride.get()), + ]) + }), + ] + .into_iter() + .flatten() + .intersperse_with(|| " + ".into()), + ), + ]), + pretty::Fragment::new([ + printer.pretty_named_argument_prefix("align"), + num_lit(buf.fixed_base.align), + ]), + ], + ")", + ), + ]), + }; + + let handles = if fixed_count.map_or(0, |c| c.get()) == 1 { + handle + } else { + pretty::Fragment::new([ + "[".into(), + fixed_count + .map(|count| { + pretty::Fragment::new([num_lit(count.get()), " × ".into()]) + }) + .unwrap_or_default(), + handle, + "]".into(), + ]) + }; + Some(handles) } + Some(qptr::shapes::GlobalVarShape::UntypedData(mem_layout)) => { + details.extend([ + pretty::Fragment::new([ + 
printer.pretty_named_argument_prefix("size"), + num_lit(mem_layout.size), + ]), + pretty::Fragment::new([ + printer.pretty_named_argument_prefix("align"), + num_lit(mem_layout.align), + ]), + ]); + None + } + Some(qptr::shapes::GlobalVarShape::TypedInterface(ty)) => Some(ty.print(printer)), + + None => Some(match spv_ptr_pointee_type { + Some(ty) => ty.print(printer), + None => pretty::Fragment::new([ + printer.error_style().apply("pointee_type_of").into(), + "(".into(), + type_of_ptr_to.print(printer), + ")".into(), + ]), + }), }; - let header = pretty::Fragment::new([addr_space_suffix, type_ascription_suffix]); - let maybe_rhs = match def { + let import = match def { + // FIXME(eddyb) deduplicate with `FuncDecl`, and maybe consider + // putting the import *before* the declaration, to end up with: + // import "..." + // as global_var GV... DeclDef::Imported(import) => Some(import.print(printer)), DeclDef::Present(GlobalVarDefBody { initializer }) => { - // FIXME(eddyb) `global_varX in AS: T = Y` feels a bit wonky for - // the initializer, but it's cleaner than obvious alternatives. 
- initializer.map(|initializer| initializer.print(printer)) + if let Some(initializer) = initializer { + details.push(pretty::Fragment::new([ + printer.pretty_named_argument_prefix("init"), + initializer.print(printer), + ])); + } + None } }; - let body = maybe_rhs.map(|rhs| pretty::Fragment::new(["= ".into(), rhs])); - let def_without_name = pretty::Fragment::new([header, pretty::join_space("", body)]); + let def_without_name = pretty::Fragment::new( + [ + (!details.is_empty()).then(|| pretty::join_comma_sep("(", details, ")")), + ascribe_type.map(|ty| pretty::join_space(":", [ty])), + import.map(|import| pretty::Fragment::new([" = ".into(), import])), + ] + .into_iter() + .flatten(), + ); AttrsAndDef { attrs: attrs.print(printer), def_without_name } } @@ -2754,9 +2763,17 @@ impl Print for FuncDecl { ]); let def_without_name = match def { - DeclDef::Imported(import) => { - pretty::Fragment::new([sig, " = ".into(), import.print(printer)]) - } + // FIXME(eddyb) deduplicate with `GlobalVarDecl`, and maybe consider + // putting the import *before* the declaration, to end up with: + // import "..." + // as func F... + DeclDef::Imported(import) => pretty::Fragment::new([ + sig, + pretty::join_space( + "", + [pretty::Fragment::new(["= ".into(), import.print(printer)])], + ), + ]), // FIXME(eddyb) this can probably go into `impl Print for FuncDefBody`. 
DeclDef::Present(def) => pretty::Fragment::new([ From c61d0968b9138f2582bf1d4b2778b15771d09ed1 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Tue, 3 Oct 2023 10:58:48 +0300 Subject: [PATCH 18/22] [WIP] GV init disaggregate --- src/lib.rs | 28 +++- src/print/mod.rs | 34 ++++- src/qptr/const_data.rs | 236 ++++++++++++++++++++++++++++++++ src/qptr/lower.rs | 117 ++++++++++++++-- src/qptr/mod.rs | 1 + src/spv/lift.rs | 89 +++++++++++- src/spv/lower.rs | 297 +++++++++++++++++++++-------------------- src/spv/mod.rs | 45 +++++++ src/transform.rs | 25 +++- src/visit.rs | 27 +++- 10 files changed, 730 insertions(+), 169 deletions(-) create mode 100644 src/qptr/const_data.rs diff --git a/src/lib.rs b/src/lib.rs index 31291009..83403c74 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -174,7 +174,7 @@ pub mod vector; use smallvec::SmallVec; use std::borrow::Cow; -use std::collections::BTreeSet; +use std::collections::{BTreeMap, BTreeSet}; use std::rc::Rc; // HACK(eddyb) work around the lack of `FxIndex{Map,Set}` type aliases elsewhere. @@ -694,10 +694,32 @@ pub enum AddrSpace { } /// The body of a [`GlobalVar`] definition. +// +// FIXME(eddyb) make "interface variables" go through imports, not definitions. #[derive(Clone)] pub struct GlobalVarDefBody { - /// If `Some`, the global variable will start out with the specified value. - pub initializer: Option, + pub initializer: Option, +} + +/// Initial contents for a [`GlobalVar`] definition. +// +// FIXME(eddyb) add special cases for undef/zeroed/etc. +// FIXME(eddyb) consider renaming this to `ConstData` or `ConstBlob`? +#[derive(Clone)] +pub enum GlobalVarInit { + /// Single valid (constant) value (see [`Value`] docs for valid types). + // + // FIXME(eddyb) does this need to be its own case at all? + Direct(Const), + + /// SPIR-V "aggregate" (`OpTypeStruct`/`OpTypeArray`), represented as its + /// non-aggregate leaves (i.e. it's disaggregated, as per [`Value`] docs). 
+ SpvAggregate { ty: Type, leaves: SmallVec<[Const; 4]> }, + + /// Non-overlapping multiple values, placed at explicit offsets. + // + // FIXME(eddyb) use a more efficient representation, like miri's. + Composite { offset_to_value: BTreeMap }, } /// Entity handle for a [`FuncDecl`](crate::FuncDecl) (a function). diff --git a/src/print/mod.rs b/src/print/mod.rs index c9e22496..f44ad9bf 100644 --- a/src/print/mod.rs +++ b/src/print/mod.rs @@ -29,8 +29,8 @@ use crate::{ ControlRegionDef, ControlRegionInputDecl, DataInst, DataInstDef, DataInstForm, DataInstFormDef, DataInstKind, DeclDef, Diag, DiagLevel, DiagMsgPart, EntityListIter, ExportKey, Exportee, Func, FuncDecl, FuncParam, FxIndexMap, FxIndexSet, GlobalVar, GlobalVarDecl, GlobalVarDefBody, - Import, Module, ModuleDebugInfo, ModuleDialect, OrdAssertEq, SelectionKind, Type, TypeDef, - TypeKind, TypeOrConst, Value, + GlobalVarInit, Import, Module, ModuleDebugInfo, ModuleDialect, OrdAssertEq, SelectionKind, + Type, TypeDef, TypeKind, TypeOrConst, Value, }; use arrayvec::ArrayVec; use itertools::Either; @@ -2726,6 +2726,36 @@ impl Print for AddrSpace { } } +impl Print for GlobalVarInit { + type Output = pretty::Fragment; + fn print(&self, printer: &Printer<'_>) -> pretty::Fragment { + match self { + GlobalVarInit::Direct(ct) => ct.print(printer), + // FIXME(eddyb) should this be recursive? 
+ GlobalVarInit::SpvAggregate { ty, leaves } => pretty::Fragment::new([ + pretty::join_comma_sep("(", leaves.iter().map(|v| v.print(printer)), ")"), + printer.pretty_type_ascription_suffix(*ty), + ]), + GlobalVarInit::Composite { offset_to_value } => pretty::join_comma_sep( + "{", + offset_to_value + .iter() + .map(|(&offset, &ct)| { + pretty::Fragment::new([ + printer.numeric_literal_style().apply(format!("{offset}")).into(), + " => ".into(), + ct.print(printer), + ]) + }) + .map(|entry| { + pretty::Fragment::new([pretty::Node::ForceLineSeparation.into(), entry]) + }), + "}", + ), + } + } +} + impl Print for FuncDecl { type Output = AttrsAndDef; fn print(&self, printer: &Printer<'_>) -> AttrsAndDef { diff --git a/src/qptr/const_data.rs b/src/qptr/const_data.rs new file mode 100644 index 00000000..50944997 --- /dev/null +++ b/src/qptr/const_data.rs @@ -0,0 +1,236 @@ +//! Constant data efficiently mixing concrete bytes with symbolic values. + +use itertools::Itertools; +use smallvec::SmallVec; +use std::collections::BTreeMap; +use std::iter; +use std::num::NonZeroU32; +use std::ops::Range; + +/// Constant data "blob" or "chunk", where each byte can be part of: +/// - uninitialized areas (e.g. SPIR-V `OpUndef`) +/// - concrete data (i.e. `u8` values) +/// - symbolic values of type `V` (spanning some number of bytes) +/// +/// This is similar to (and inspired by), [`rustc`'s `mir::interpret::Allocation`]( +/// https://rustc-dev-guide.rust-lang.org/const-eval/interpret.html#memory), +/// which only has abstract pointers as symbolic values, encoded as "relocations" +/// (i.e. concrete data contains the respective offset for each abstract pointer, +/// whereas here the symbolic values are completely disjoint with concrete data). +pub struct ConstData { + /// The bit `init[i / 64] & (1 << (i % 64))` is set iff byte offset `i` is + /// initialized, either with concrete data, or as part of a symbolic value. 
+ init: Vec, + + /// Concrete data bytes, with each byte only used when `init` indicates + /// it is initialized *and* no symbolic value overlaps it. Unused bytes can + /// have any values in `bytes`, as they're guaranteed to be always ignored. + data: Vec, + + /// Non-overlapping set of symbolic `V` values, forming an "overlay" on top + /// of the concrete data bytes, with `syms[offset] = (size, value)` + /// indicating bytes `offset..(offset + size)` are occupied by `value`. + syms: BTreeMap, + + /// Largest symbolic value size, i.e. `syms.values().map(|(size, _)| size).max()`. + // + // FIXME(eddyb) this is only needed to help with scanning overlaps in `syms`, + // and because there is no inherent limit on the size of symbolic values. + max_sym_size: NonZeroU32, +} + +/// One uniform "slice" of a `ConstData` (*not* mixing value categories). +#[derive(Clone)] +pub enum Part<'a, V> { + Uninit { + size: u32, + }, + Data(&'a [u8]), + Symbolic { + size: NonZeroU32, + /// This is only the full `value` if `slice == 0..size`. + slice: Range, + value: V, + }, +} + +/// Error type for write operations, emitted when they would otherwise cause a +/// partial overwrite of a symbolic value, if allowed to take effect. +#[derive(Debug)] +pub struct PartialSymbolicOverlap { + pub offsets: Range, +} + +// FIXME(eddyb) come up with a nicer abstraction for bitvecs, or use a crate. 
+fn bitrange_word_chunks(range: Range) -> (Range, impl Iterator>) { + let words = (range.start / 64)..(range.end.div_ceil(64)); + ( + (words.start as usize)..(words.end as usize), + words.map(move |i| { + (((i * 64).clamp(range.start, range.end) % 64) as u8) + ..((((i + 1) * 64).clamp(range.start, range.end) % 64) as u8) + }), + ) +} + +impl ConstData { + pub fn new(size: u32) -> Self { + let size = size as usize; + Self { + init: vec![0; size.div_ceil(64)], + data: vec![0; size], + syms: BTreeMap::new(), + max_sym_size: NonZeroU32::new(1).unwrap(), + } + } + + pub fn size(&self) -> u32 { + self.data.len() as u32 + } + + pub fn read(&self, range: Range) -> impl Iterator> { + // HACK(eddyb) trigger bounds-checking panics. + let _ = &self.data[(range.start as usize)..(range.end as usize)]; + + // HACK(eddyb) the range has to be extended backwards, because a partial + // overlap could exist, i.e. `range.start` being in the middle of a value, + // but then irrelevant values have to be ignored. + let mut syms = self + .syms + .range((range.start - (self.max_sym_size.get() - 1))..range.end) + .map(|(&offset, &(size, value))| (offset..(offset + size.get()), value)) + .peekable(); + while let Some((sym_range, _)) = syms.peek() { + if sym_range.end > range.start { + break; + } + syms.next().unwrap(); + } + + let mut part_start = range.start; + iter::from_fn(move || { + if part_start >= range.end { + return None; + } + let next_sym_range = syms.peek().cloned().map_or(range.end..range.end, |(r, _)| r); + + let max_part_end = if next_sym_range.contains(&part_start) { + next_sym_range.end + } else { + next_sym_range.start + }; + // FIXME(eddyb) come up with a nicer abstraction for bitvecs, or use a crate. 
+ let (part_is_init, part_size) = { + let (words, word_bitslices) = bitrange_word_chunks(part_start..max_part_end); + self.init[words] + .iter() + .zip_eq(word_bitslices) + .flat_map(|(&word, word_bitslice)| { + let sliced_word = + (word >> word_bitslice.start) & (!0 >> (64 - word_bitslice.end)); + let first = (sliced_word & 1) != 0; + let same_run = if first { + sliced_word.trailing_ones() + } else { + sliced_word.trailing_zeros() + }; + [(first, same_run)] + .into_iter() + .chain((same_run < word_bitslice.len() as u32).then_some((!first, 0))) + }) + .coalesce(|(a, a_run), (b, b_run)| { + if a == b { Ok((a, a_run + b_run)) } else { Err(((a, a_run), (b, b_run))) } + }) + .next() + .unwrap() + }; + + let part_end = part_start + part_size; + let part = if !part_is_init { + Part::Uninit { size: part_size } + } else if next_sym_range.contains(&part_start) { + let (sym_range, value) = syms.next().unwrap(); + // HACK(eddyb) ensure slicing is caused by `range`, *not* `init`. + assert_eq!( + part_start..part_end, + sym_range.start.clamp(range.start, range.end) + ..sym_range.end.clamp(range.start, range.end) + ); + Part::Symbolic { + size: NonZeroU32::new(sym_range.len() as u32).unwrap(), + slice: (part_start - sym_range.start)..(part_end - sym_range.start), + value, + } + } else { + Part::Data(&self.data[(part_start as usize)..(part_end as usize)]) + }; + part_start = part_end; + Some(part) + }) + } + + /// Helper for `write_bytes` and `write_symbolic`, which only modifies `self` + /// (removing fully overwritten symbolic values, and setting `init` bits), + /// when it can guarantee it will return `Ok(())` (i.e. after error checks). + fn try_init(&mut self, range: Range) -> Result<(), PartialSymbolicOverlap> { + // HACK(eddyb) trigger bounds-checking panics. + let _ = &self.data[(range.start as usize)..(range.end as usize)]; + + // HACK(eddyb) the range has to be extended backwards, because a partial + // overlap could exist, i.e. 
`range.start` being in the middle of a value, + // but then irrelevant values have to be ignored. + let syms_ranges = self + .syms + .range((range.start - (self.max_sym_size.get() - 1))..range.end) + .map(|(&offset, &(size, _))| offset..(offset + size.get())); + + // FIXME(eddyb) this is a bit inefficient but we don't have + // cursors, so we have to buffer the `BTreeMap` keys here. + let mut full_overwritten_sym_offsets = SmallVec::<[u32; 16]>::new(); + for sym_range in syms_ranges { + let overlap = sym_range.start.clamp(range.start, range.end) + ..sym_range.end.clamp(range.start, range.end); + if overlap.is_empty() { + continue; + } + if overlap == sym_range { + full_overwritten_sym_offsets.push(sym_range.start); + } else { + return Err(PartialSymbolicOverlap { offsets: overlap }); + } + } + for offset in full_overwritten_sym_offsets { + self.syms.remove(&offset); + } + + // FIXME(eddyb) come up with a nicer abstraction for bitvecs, or use a crate. + { + let (words, word_bitslices) = bitrange_word_chunks(range); + for (word, word_bitslice) in self.init[words].iter_mut().zip(word_bitslices) { + *word |= (!0 << word_bitslice.start) & (!0 >> (64 - word_bitslice.end)); + } + } + + Ok(()) + } + + pub fn write_bytes(&mut self, offset: u32, bytes: &[u8]) -> Result<(), PartialSymbolicOverlap> { + let range = offset..(offset + bytes.len() as u32); + self.try_init(range.clone())?; + self.data[(range.start as usize)..(range.end as usize)].copy_from_slice(bytes); + Ok(()) + } + + // FIXME(eddyb) should this take an offset range instead? 
+ pub fn write_symbolic( + &mut self, + offset: u32, + size: NonZeroU32, + value: V, + ) -> Result<(), PartialSymbolicOverlap> { + let range = offset..(offset + size.get()); + self.try_init(range.clone())?; + self.syms.insert(offset, (size, value)); + Ok(()) + } +} diff --git a/src/qptr/lower.rs b/src/qptr/lower.rs index 7515b03b..7c674da2 100644 --- a/src/qptr/lower.rs +++ b/src/qptr/lower.rs @@ -9,13 +9,14 @@ use crate::transform::{InnerInPlaceTransform, Transformed, Transformer}; use crate::{ spv, AddrSpace, AttrSetDef, Const, ConstDef, ConstKind, Context, ControlNode, ControlNodeKind, DataInst, DataInstDef, DataInstForm, DataInstFormDef, DataInstKind, DeclDef, Diag, - EntityOrientedDenseMap, FuncDecl, GlobalVarDecl, OrdAssertEq, Type, TypeKind, TypeOrConst, - Value, + EntityOrientedDenseMap, FuncDecl, GlobalVarDecl, GlobalVarInit, OrdAssertEq, Type, TypeKind, + TypeOrConst, Value, }; use itertools::Either; use rustc_hash::FxHashMap; use smallvec::SmallVec; use std::cell::Cell; +use std::collections::BTreeMap; use std::mem; use std::num::NonZeroU32; use std::rc::Rc; @@ -126,20 +127,116 @@ impl<'a> LowerFromSpvPtrs<'a> { global_var_decl.addr_space = AddrSpace::Handles; } } + + match &mut global_var_decl.def { + DeclDef::Imported(_) => {} + DeclDef::Present(global_var_def_body) => match &global_var_def_body.initializer { + None | Some(GlobalVarInit::Direct(_)) => {} + + Some(GlobalVarInit::Composite { .. }) => { + global_var_decl.attrs.push_diag( + &self.cx, + Diag::bug([ + "unexpected `GlobalVarInit::Composite` (already lowered?)".into() + ]), + ); + } + + Some(GlobalVarInit::SpvAggregate { ty, leaves }) => { + let lowered_initializer = self + .layout_of(*ty) + .and_then(|layout| match layout { + // FIXME(eddyb) consider bad interactions with "interface blocks"? + TypeLayout::Handle(_) | TypeLayout::HandleArray(..) 
=> { + Err(LowerError(Diag::bug(["handles are not aggregates".into()]))) + } + TypeLayout::Concrete(layout) => Ok(layout), + }) + .and_then(|aggregate_layout| { + let mut leaf_values = leaves.iter().copied(); + let mut offset_to_value = BTreeMap::new(); + aggregate_layout + .deeply_flatten_if( + 0, + // Whether `candidate_layout` is an aggregate (to recurse into). + &|candidate_layout| { + matches!( + &self.cx[candidate_layout.original_type].kind, + TypeKind::SpvInst { + value_lowering: spv::ValueLowering::Disaggregate(_), + .. + } + ) + }, + &mut |leaf_offset, leaf| { + let leaf_offset = + u32::try_from(leaf_offset).ok().ok_or_else(|| { + LayoutError(Diag::bug([format!( + "negative initializer leaf offset {leaf_offset}" + ) + .into()])) + })?; + + let leaf_value = leaf_values.next().ok_or_else(|| { + LayoutError(Diag::bug([ + "fewer initializer leaves than layout".into(), + ])) + })?; + + // FIXME(eddyb) should this compare only size/shape? + let expected_ty = leaf.original_type; + let found_ty = self.cx[leaf_value].ty; + if expected_ty != found_ty { + return Err(LayoutError(Diag::bug([ + "initializer leaf type mismatch: expected `".into(), + expected_ty.into(), + "`, found `".into(), + found_ty.into(), + "`".into(), + ]))); + } + + offset_to_value.insert(leaf_offset, leaf_value); + + Ok(()) + }, + ) + .map_err(|LayoutError(e)| LowerError(e))?; + + if leaf_values.next().is_some() { + return Err(LowerError(Diag::bug([ + "more initializer leaves than layout".into(), + ]))); + } + + Ok(GlobalVarInit::Composite { offset_to_value }) + }); + match lowered_initializer { + Ok(initializer) => { + global_var_def_body.initializer = Some(initializer); + } + Err(LowerError(e)) => { + global_var_decl.attrs.push_diag(&self.cx, e); + } + } + } + }, + } + + // HACK(eddyb) in case anything goes wrong, we want to keep `OpTypePointer`. 
+ let original_type_of_ptr_to = global_var_decl.type_of_ptr_to; + + EraseSpvPtrs { lowerer: self }.in_place_transform_global_var_decl(global_var_decl); + match shape_result { Ok(shape) => { global_var_decl.shape = Some(shape); - - // HACK(eddyb) this should handle shallow `QPtr` in the initializer, but - // typed initializers should be replaced with miri/linker-style ones. - // FIXME(eddyb) this is even worse now, with disaggregation, - // the initializer should be disaggregated leaves, which then - // need to flattened into a miri-like representation, or at least - // have offsets assigned to each leaf (for `qptr::lift` to use). - EraseSpvPtrs { lowerer: self }.in_place_transform_global_var_decl(global_var_decl); } Err(LowerError(e)) => { global_var_decl.attrs.push_diag(&self.cx, e); + + // HACK(eddyb) effectively undoes `EraseSpvPtrs` for one field. + global_var_decl.type_of_ptr_to = original_type_of_ptr_to; } } } diff --git a/src/qptr/mod.rs b/src/qptr/mod.rs index 65221c93..513bcb47 100644 --- a/src/qptr/mod.rs +++ b/src/qptr/mod.rs @@ -13,6 +13,7 @@ use std::rc::Rc; // NOTE(eddyb) all the modules are declared here, but they're documented "inside" // (i.e. using inner doc comments). 
pub mod analyze; +pub mod const_data; mod layout; pub mod lift; pub mod lower; diff --git a/src/spv/lift.rs b/src/spv/lift.rs index 2d6e7dce..658ebc43 100644 --- a/src/spv/lift.rs +++ b/src/spv/lift.rs @@ -8,8 +8,8 @@ use crate::{ ControlNodeKind, ControlNodeOutputDecl, ControlRegion, ControlRegionInputDecl, DataInst, DataInstDef, DataInstForm, DataInstFormDef, DataInstKind, DeclDef, EntityList, EntityOrientedDenseMap, ExportKey, Exportee, Func, FuncDecl, FuncDefBody, FuncParam, - FxIndexMap, FxIndexSet, GlobalVar, GlobalVarDefBody, Import, Module, ModuleDebugInfo, - ModuleDialect, SelectionKind, Type, TypeDef, TypeKind, TypeOrConst, Value, + FxIndexMap, FxIndexSet, GlobalVar, GlobalVarDefBody, GlobalVarInit, Import, Module, + ModuleDebugInfo, ModuleDialect, SelectionKind, Type, TypeDef, TypeKind, TypeOrConst, Value, }; use itertools::Itertools; use rustc_hash::FxHashMap; @@ -19,6 +19,7 @@ use std::collections::BTreeMap; use std::num::NonZeroUsize; use std::ops::Range; use std::path::Path; +use std::rc::Rc; use std::{io, iter, mem, slice}; // HACK(eddyb) getting around the lack of a `Step` impl on `spv::Id` (`NonZeroU32`). @@ -118,6 +119,9 @@ struct ModuleIds<'a> { globals: FxIndexMap, // FIXME(eddyb) use `EntityOrientedDenseMap` here. funcs: FxIndexMap>, + + // FIXME(eddyb) should this be somehow snuck into `globals`? 
+ reaggregated_global_var_initializers: FxHashMap, } #[derive(Copy, Clone, PartialEq, Eq, Hash)] @@ -236,8 +240,34 @@ impl Visitor<'_> for Lifter<'_, AI> { } fn visit_global_var_use(&mut self, gv: GlobalVar) { - if self.global_vars_seen.insert(gv) { - self.visit_global_var_decl(&self.module.global_vars[gv]); + if !self.global_vars_seen.insert(gv) { + return; + } + let gv_decl = &self.module.global_vars[gv]; + self.visit_global_var_decl(gv_decl); + + match &gv_decl.def { + DeclDef::Imported(_) => {} + DeclDef::Present(gv_def_body) => match &gv_def_body.initializer { + None | Some(GlobalVarInit::Direct(_)) => {} + + // FIXME(eddyb) this should be a proper `Result`-based error instead, + // and/or `spv::lift` should mutate the module for legalization. + Some(GlobalVarInit::Composite { .. }) => { + unreachable!( + "`GlobalVarInit::Composite` should be legalized away before lifting" + ); + } + + // HACK(eddyb) recursively reconstruct an initializer as a tree + // of (otherwise illegal) `Const`s, with SPIR-V aggregate types. + // FIXME(eddyb) this *technically* pollutes the `Context`, but + // is easier than having two ways of tracking SPIR-V constants. + Some(GlobalVarInit::SpvAggregate { ty, leaves }) => { + let init = self.reaggregate_const(*ty, leaves); + self.ids.reaggregated_global_var_initializers.insert(gv, init); + } + }, } } fn visit_func_use(&mut self, func: Func) { @@ -345,6 +375,45 @@ impl Visitor<'_> for Lifter<'_, AI> { } } +impl Lifter<'_, AI> { + // FIXME(eddyb) maybe use this for `DataInstDef` inputs as well, when `Const`s, + // not just `GlobalVarInit::SpvAggregate`? 
+ fn reaggregate_const(&mut self, ty: Type, leaves: &[Const]) -> Const { + let ty_def = &self.cx[ty]; + assert_eq!(leaves.len(), ty_def.disaggregated_leaf_count()); + + if let spv::ValueLowering::Direct = ty_def.spv_value_lowering() { + let &[ct] = leaves.try_into().unwrap(); + return ct; + } + + // HACK(eddyb) this is a bit inefficient but increases code reuse, in + // a case that'd otherwise require e.g. an `Iterator` w/ `nth` overload. + let mut used_leaves = 0..0; + let components = (0..) + .map_while(|i| ty.aggregate_component_type_and_leaf_range(self.cx, i)) + .map(|(component_type, component_leaf_range)| { + assert_eq!(used_leaves.end, component_leaf_range.start); + used_leaves.end = component_leaf_range.end; + self.reaggregate_const(component_type, &leaves[component_leaf_range]) + }) + .collect(); + assert_eq!(used_leaves, 0..leaves.len()); + + let wk = &spec::Spec::get().well_known; + let ct = self.cx.intern(ConstDef { + attrs: AttrSet::default(), + ty, + kind: ConstKind::SpvInst { + spv_inst_and_const_inputs: Rc::new((wk.OpConstantComposite.into(), components)), + }, + }); + // HACK(eddyb) visit constants as they're created, to ensure they're recorded. + self.visit_const_use(ct); + ct + } +} + // FIXME(eddyb) this is inconsistently named with `FuncIds`. struct FuncBodyLifting<'a> { region_inputs_source: EntityOrientedDenseMap, @@ -1410,9 +1479,19 @@ impl LazyInst<'_, '_> { spv::Imm::Short(wk.StorageClass, sc) } }; - let initializer = match gv_decl.def { + let initializer = match &gv_decl.def { DeclDef::Imported(_) => None, DeclDef::Present(GlobalVarDefBody { initializer }) => initializer + .as_ref() + .map(|initializer| match initializer { + // Disallowed while visiting. + GlobalVarInit::Composite { .. } => unreachable!(), + + &GlobalVarInit::Direct(ct) => ct, + GlobalVarInit::SpvAggregate { .. 
} => { + ids.reaggregated_global_var_initializers[&gv] + } + }) .map(|initializer| ids.globals[&Global::Const(initializer)]), }; spv::InstWithIds { diff --git a/src/spv/lower.rs b/src/spv/lower.rs index 6d2c8575..5484df90 100644 --- a/src/spv/lower.rs +++ b/src/spv/lower.rs @@ -7,8 +7,8 @@ use crate::{ ControlNodeDef, ControlNodeKind, ControlRegion, ControlRegionDef, ControlRegionInputDecl, DataInst, DataInstDef, DataInstFormDef, DataInstKind, DeclDef, Diag, EntityDefs, EntityList, ExportKey, Exportee, Func, FuncDecl, FuncDefBody, FuncParam, FxIndexMap, GlobalVarDecl, - GlobalVarDefBody, Import, InternedStr, Module, SelectionKind, Type, TypeDef, TypeKind, - TypeOrConst, Value, + GlobalVarDefBody, GlobalVarInit, Import, InternedStr, Module, SelectionKind, Type, TypeDef, + TypeKind, TypeOrConst, Value, }; use itertools::Either; use rustc_hash::FxHashMap; @@ -29,12 +29,9 @@ enum IdDef { /// constants (e.g. `OpConstantComposite`s of those types, but also more /// general constants like `OpUndef`/`OpConstantNull` etc.). AggregateConst { - // FIXME(eddyb) remove `whole_const` by always using the `leaves`. - whole_const: Const, - whole_type: Type, - leaves: SmallVec<[Const; 2]>, + leaves: SmallVec<[Const; 4]>, }, Func(Func), @@ -64,67 +61,20 @@ impl IdDef { } impl Type { - fn aggregate_component_leaf_range_and_type( - self, - cx: &Context, - idx: u32, - ) -> Option<(Range, Type)> { - let (type_and_const_inputs, aggregate_shape) = match &cx[self].kind { - TypeKind::SpvInst { - spv_inst: _, - type_and_const_inputs, - value_lowering: spv::ValueLowering::Disaggregate(aggregate_shape), - } => (type_and_const_inputs, aggregate_shape), - _ => return None, - }; - let expect_type = |ty_or_ct| match ty_or_ct { - TypeOrConst::Type(ty) => ty, - TypeOrConst::Const(_) => unreachable!(), - }; - - let idx_usize = idx as usize; - let component_type = match aggregate_shape { - spv::AggregateShape::Struct { .. } => { - expect_type(*type_and_const_inputs.get(idx_usize)?) 
- } - &spv::AggregateShape::Array { fixed_len, .. } => { - if idx >= fixed_len { - return None; - } - expect_type(type_and_const_inputs[0]) - } - }; - let component_leaf_count = cx[component_type].disaggregated_leaf_count(); - - let component_leaf_range = match aggregate_shape { - spv::AggregateShape::Struct { per_field_leaf_range_end } => { - let end = per_field_leaf_range_end[idx_usize] as usize; - let start = end.checked_sub(component_leaf_count)?; - start..end - } - spv::AggregateShape::Array { .. } => { - let start = component_leaf_count.checked_mul(idx_usize)?; - let end = start.checked_add(component_leaf_count)?; - start..end - } - }; - Some((component_leaf_range, component_type)) - } - // HACK(eddyb) `indices` is a `&mut` because it specifically only consumes // the indices it needs, so when this function returns `Some`, all remaining // indices will be left over for the caller to process itself. - fn aggregate_component_path_leaf_range_and_type( + fn aggregate_component_path_type_and_leaf_range( self, cx: &Context, indices: &mut impl Iterator, - ) -> Option<(Range, Type)> { - let (mut leaf_range, mut leaf_type) = - self.aggregate_component_leaf_range_and_type(cx, indices.next()?)?; + ) -> Option<(Type, Range)> { + let (mut leaf_type, mut leaf_range) = + self.aggregate_component_type_and_leaf_range(cx, indices.next()?)?; while let spv::ValueLowering::Disaggregate(_) = cx[leaf_type].spv_value_lowering() { - let (sub_leaf_range, sub_leaf_type) = match indices.next() { - Some(i) => leaf_type.aggregate_component_leaf_range_and_type(cx, i)?, + let (sub_leaf_type, sub_leaf_range) = match indices.next() { + Some(i) => leaf_type.aggregate_component_type_and_leaf_range(cx, i)?, None => break, }; @@ -134,7 +84,7 @@ impl Type { leaf_type = sub_leaf_type; } - Some((leaf_range, leaf_type)) + Some((leaf_type, leaf_range)) } } @@ -704,8 +654,67 @@ impl Module { let ty = result_type.unwrap(); - let mut aggregate_leaves = match cx[ty].spv_value_lowering() { - 
spv::ValueLowering::Direct => None, + // HACK(eddyb) while creating constants of unsized array types + // is *technically* illegal in SPIR-V, array semantics always + // are length-independent, so we can pretend this is an array + // of the right length (as long as we track the error on it). + let maybe_fixup_unsized_array_type = |ty: Type| { + if ![wk.OpConstantComposite, wk.OpSpecConstantComposite].contains(&opcode) { + return None; + }; + let actual_component_count = u32::try_from(inst.ids.len()).ok()?; + + let ty_def = &cx[ty]; + let elem_type_of_unsized_array = match &ty_def.kind { + TypeKind::SpvInst { spv_inst: ty_inst, type_and_const_inputs, .. } => { + match type_and_const_inputs[..] { + [TypeOrConst::Type(elem_type), TypeOrConst::Const(len)] + if ty_inst.opcode == wk.OpTypeArray + && len.as_scalar(&cx).is_none() => + { + elem_type + } + [TypeOrConst::Type(elem_type)] + if ty_inst.opcode == wk.OpTypeRuntimeArray => + { + elem_type + } + _ => return None, + } + } + _ => return None, + }; + Some( + cx.intern(TypeDef { + attrs: ty_def.attrs.append_diag( + &cx, + Diag::err([ + "illegal constant: values of type `".into(), + ty.into(), + "` should only be accessed through pointers".into(), + ]), + ), + kind: spv::Inst::from(wk.OpTypeArray).into_canonical_type_with( + &cx, + [ + TypeOrConst::Type(elem_type_of_unsized_array), + TypeOrConst::Const( + cx.intern(scalar::Const::from_u32(actual_component_count)), + ), + ] + .into_iter() + .collect(), + ), + }), + ) + }; + let ty = maybe_fixup_unsized_array_type(ty).unwrap_or(ty); + + let mut all_leaves = SmallVec::new(); + match cx[ty].spv_value_lowering() { + spv::ValueLowering::Direct => { + all_leaves.reserve(inst.ids.len()); + } spv::ValueLowering::Disaggregate(_) => { // HACK(eddyb) this expands `OpUndef`/`OpConstantNull`. // FIXME(eddyb) this could potentially create a very @@ -713,101 +722,105 @@ impl Module { // be expressed much more compactly in theory. 
if inst.lower_const_by_distributing_to_aggregate_leaves() { assert_eq!(inst.ids.len(), 0); - Some( - ty.disaggregated_leaf_types(&cx) - .map(|leaf_type| { - cx.intern(ConstDef { - attrs: Default::default(), - ty: leaf_type, - kind: inst - .as_canonical_const(&cx, leaf_type, &[]) - .unwrap_or_else(|| ConstKind::SpvInst { - spv_inst_and_const_inputs: Rc::new(( - inst.without_ids.clone(), - [].into_iter().collect(), - )), - }), - }) - }) - .collect(), - ) + all_leaves.extend(ty.disaggregated_leaf_types(&cx).map(|leaf_type| { + cx.intern(ConstDef { + attrs: Default::default(), + ty: leaf_type, + kind: inst + .as_canonical_const(&cx, leaf_type, &[]) + .unwrap_or_else(|| ConstKind::SpvInst { + spv_inst_and_const_inputs: Rc::new(( + inst.without_ids.clone(), + [].into_iter().collect(), + )), + }), + }) + })); } else if [wk.OpConstantComposite, wk.OpSpecConstantComposite] .contains(&opcode) { - // NOTE(eddyb) actual leaves gathered below, while - // collecting `const_inputs`. - Some(SmallVec::with_capacity(cx[ty].disaggregated_leaf_count())) + all_leaves.reserve(cx[ty].disaggregated_leaf_count()); } else { attrs.push_diag( &cx, Diag::bug(["unsupported aggregate-producing constant".into()]), ); - None } } - }; + } - let const_inputs: SmallVec<_> = inst - .ids - .iter() - .map(|&id| match id_defs.get(&id) { + let invalid = |descr| invalid(&format!("unsupported use of {descr} in a constant")); + for &id in &inst.ids { + match id_defs.get(&id) { Some(&IdDef::Const(ct)) => { - if let Some(aggregate_leaves) = &mut aggregate_leaves { - aggregate_leaves.push(ct); - } - Ok(ct) + all_leaves.push(ct); } - Some(IdDef::AggregateConst { whole_const, whole_type: _, leaves }) => { - if let Some(aggregate_leaves) = &mut aggregate_leaves { - aggregate_leaves.extend(leaves.iter().copied()); + Some(IdDef::AggregateConst { whole_type, leaves }) => { + all_leaves.extend(leaves.iter().copied()); + + match cx[ty].spv_value_lowering() { + // FIXME(eddyb) this also covers invalid consts + // of 
e.g. unsized aggregate types, as well. + spv::ValueLowering::Direct => { + attrs.push_diag( + &cx, + Diag::err([ + "unexpected aggregate constant of type `".into(), + (*whole_type).into(), + "`".into(), + ]), + ); + } + spv::ValueLowering::Disaggregate(_) => {} } - Ok(*whole_const) } - Some(id_def) => Err(id_def.descr(&cx)), - None => Err(format!("a forward reference to %{id}")), - }) - .map(|result| { - result.map_err(|descr| { - invalid(&format!("unsupported use of {descr} in a constant")) - }) - }) - .collect::>()?; + Some(id_def) => return Err(invalid(&id_def.descr(&cx))), + None => return Err(invalid(&format!("a forward reference to %{id}"))), + } + } - if let (spv::ValueLowering::Disaggregate(_), Some(leaves)) = - (cx[ty].spv_value_lowering(), &aggregate_leaves) - { - if cx[ty].disaggregated_leaf_count() != leaves.len() { + let lowering = &cx[ty].spv_value_lowering(); + let lowering = match lowering { + spv::ValueLowering::Disaggregate(_) + if cx[ty].disaggregated_leaf_count() != all_leaves.len() => + { attrs.push_diag( &cx, Diag::err([format!( "aggregate leaf count mismatch (expected {}, found {})", cx[ty].disaggregated_leaf_count(), - leaves.len() + all_leaves.len() ) .into()]), ); - aggregate_leaves = None; + // HACK(eddyb) pretend the type isn't an aggregate, so + // that it doesn't end up using `IdDef::AggregateConst`, + // which requires having the exact number of leaves. + &spv::ValueLowering::Direct } - } + _ => lowering, + }; - let ct = cx.intern(ConstDef { - attrs: mem::take(&mut attrs), - ty, - kind: inst.as_canonical_const(&cx, ty, &const_inputs).unwrap_or_else(|| { - ConstKind::SpvInst { - spv_inst_and_const_inputs: Rc::new((inst.without_ids, const_inputs)), - } - }), - }); + let attrs = mem::take(&mut attrs); id_defs.insert( id, - match (cx[ty].spv_value_lowering(), aggregate_leaves) { - (spv::ValueLowering::Disaggregate(_), Some(leaves)) => { - // FIXME(eddyb) this may lose semantic `attrs` when - // `leaves` are directly used. 
- IdDef::AggregateConst { whole_const: ct, whole_type: ty, leaves } + match lowering { + spv::ValueLowering::Direct => IdDef::Const(cx.intern(ConstDef { + attrs, + ty, + kind: inst.as_canonical_const(&cx, ty, &all_leaves).unwrap_or_else( + || ConstKind::SpvInst { + spv_inst_and_const_inputs: Rc::new(( + inst.without_ids, + all_leaves, + )), + }, + ), + })), + spv::ValueLowering::Disaggregate(_) => { + // FIXME(eddyb) this may lose semantic `attrs`. + IdDef::AggregateConst { whole_type: ty, leaves: all_leaves } } - _ => IdDef::Const(ct), }, ); @@ -841,10 +854,12 @@ impl Module { let initializer = initializer .map(|id| match id_defs.get(&id) { - Some(&IdDef::Const(ct)) => Ok(ct), - Some(&IdDef::AggregateConst { whole_const, .. }) => { - // FIXME(eddyb) disaggregate global initializers. - Ok(whole_const) + Some(&IdDef::Const(ct)) => Ok(GlobalVarInit::Direct(ct)), + Some(IdDef::AggregateConst { whole_type, leaves }) => { + Ok(GlobalVarInit::SpvAggregate { + ty: *whole_type, + leaves: leaves.clone(), + }) } Some(id_def) => Err(id_def.descr(&cx)), None => Err(format!("a forward reference to %{id}")), @@ -1257,14 +1272,12 @@ impl Module { whole_type: cx[ct].ty, leaves: Either::Right(Either::Left([Value::Const(ct)].into_iter())), }), - Some(IdDef::AggregateConst { whole_const: _, whole_type, leaves }) => { - Ok(LocalIdDef::Value { - whole_type: *whole_type, - leaves: Either::Right(Either::Right( - leaves.iter().copied().map(Value::Const), - )), - }) - } + Some(IdDef::AggregateConst { whole_type, leaves }) => Ok(LocalIdDef::Value { + whole_type: *whole_type, + leaves: Either::Right(Either::Right( + leaves.iter().copied().map(Value::Const), + )), + }), Some(id_def @ IdDef::Type(_)) => Err(invalid(&format!( "unsupported use of {} as an operand for \ an instruction in a function", @@ -1758,11 +1771,11 @@ impl Module { } let mut imms = imms.iter(); - let (leaf_range, leaf_type) = match cx[composite_type].spv_value_lowering() + let (leaf_type, leaf_range) = match 
cx[composite_type].spv_value_lowering() { spv::ValueLowering::Direct => return None, spv::ValueLowering::Disaggregate(_) => composite_type - .aggregate_component_path_leaf_range_and_type( + .aggregate_component_path_type_and_leaf_range( &cx, &mut imms.by_ref().map(|&imm| match imm { spv::Imm::Short(_, i) => i, diff --git a/src/spv/mod.rs b/src/spv/mod.rs index 2b0156c5..de35bb7b 100644 --- a/src/spv/mod.rs +++ b/src/spv/mod.rs @@ -302,6 +302,51 @@ impl Type { parent_component_path: SmallVec::new(), })) } + + fn aggregate_component_type_and_leaf_range( + self, + cx: &Context, + idx: u32, + ) -> Option<(Type, Range)> { + let (type_and_const_inputs, aggregate_shape) = match &cx[self].kind { + TypeKind::SpvInst { + spv_inst: _, + type_and_const_inputs, + value_lowering: ValueLowering::Disaggregate(aggregate_shape), + } => (type_and_const_inputs, aggregate_shape), + _ => return None, + }; + let expect_type = |ty_or_ct| match ty_or_ct { + TypeOrConst::Type(ty) => ty, + TypeOrConst::Const(_) => unreachable!(), + }; + + let idx_usize = idx as usize; + let component_type = match aggregate_shape { + AggregateShape::Struct { .. } => expect_type(*type_and_const_inputs.get(idx_usize)?), + &AggregateShape::Array { fixed_len, .. } => { + if idx >= fixed_len { + return None; + } + expect_type(type_and_const_inputs[0]) + } + }; + let component_leaf_count = cx[component_type].disaggregated_leaf_count(); + + let component_leaf_range = match aggregate_shape { + AggregateShape::Struct { per_field_leaf_range_end } => { + let end = per_field_leaf_range_end[idx_usize] as usize; + let start = end.checked_sub(component_leaf_count)?; + start..end + } + AggregateShape::Array { .. 
} => { + let start = component_leaf_count.checked_mul(idx_usize)?; + let end = start.checked_add(component_leaf_count)?; + start..end + } + }; + Some((component_type, component_leaf_range)) + } } /// Aspects of how a [`spv::Inst`](Inst) was produced by [`spv::lower`](lower), diff --git a/src/transform.rs b/src/transform.rs index 1b5b0a52..37debc65 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -7,8 +7,8 @@ use crate::{ ControlNodeDef, ControlNodeKind, ControlNodeOutputDecl, ControlRegion, ControlRegionDef, ControlRegionInputDecl, DataInst, DataInstDef, DataInstForm, DataInstFormDef, DataInstKind, DeclDef, EntityListIter, ExportKey, Exportee, Func, FuncDecl, FuncDefBody, FuncParam, - GlobalVar, GlobalVarDecl, GlobalVarDefBody, Import, Module, ModuleDebugInfo, ModuleDialect, - OrdAssertEq, SelectionKind, Type, TypeDef, TypeKind, TypeOrConst, Value, + GlobalVar, GlobalVarDecl, GlobalVarDefBody, GlobalVarInit, Import, Module, ModuleDebugInfo, + ModuleDialect, OrdAssertEq, SelectionKind, Type, TypeDef, TypeKind, TypeOrConst, Value, }; use std::cmp::Ordering; use std::rc::Rc; @@ -526,7 +526,26 @@ impl InnerInPlaceTransform for GlobalVarDefBody { let Self { initializer } = self; if let Some(initializer) = initializer { - transformer.transform_const_use(*initializer).apply_to(initializer); + initializer.inner_in_place_transform_with(transformer); + } + } +} + +impl InnerInPlaceTransform for GlobalVarInit { + fn inner_in_place_transform_with(&mut self, transformer: &mut impl Transformer) { + match self { + GlobalVarInit::Direct(ct) => transformer.transform_const_use(*ct).apply_to(ct), + GlobalVarInit::SpvAggregate { ty, leaves } => { + transformer.transform_type_use(*ty).apply_to(ty); + for ct in leaves { + transformer.transform_const_use(*ct).apply_to(ct); + } + } + GlobalVarInit::Composite { offset_to_value } => { + for ct in offset_to_value.values_mut() { + transformer.transform_const_use(*ct).apply_to(ct); + } + } } } } diff --git a/src/visit.rs b/src/visit.rs 
index d75b00c5..0faa6ebc 100644 --- a/src/visit.rs +++ b/src/visit.rs @@ -7,8 +7,8 @@ use crate::{ ControlNodeDef, ControlNodeKind, ControlNodeOutputDecl, ControlRegion, ControlRegionDef, ControlRegionInputDecl, DataInstDef, DataInstForm, DataInstFormDef, DataInstKind, DeclDef, DiagMsgPart, EntityListIter, ExportKey, Exportee, Func, FuncDecl, FuncDefBody, FuncParam, - GlobalVar, GlobalVarDecl, GlobalVarDefBody, Import, Module, ModuleDebugInfo, ModuleDialect, - SelectionKind, Type, TypeDef, TypeKind, TypeOrConst, Value, + GlobalVar, GlobalVarDecl, GlobalVarDefBody, GlobalVarInit, Import, Module, ModuleDebugInfo, + ModuleDialect, SelectionKind, Type, TypeDef, TypeKind, TypeOrConst, Value, }; // FIXME(eddyb) `Sized` bound shouldn't be needed but removing it requires @@ -388,8 +388,27 @@ impl InnerVisit for GlobalVarDefBody { fn inner_visit_with<'a>(&'a self, visitor: &mut impl Visitor<'a>) { let Self { initializer } = self; - if let Some(initializer) = *initializer { - visitor.visit_const_use(initializer); + if let Some(initializer) = initializer { + initializer.inner_visit_with(visitor); + } + } +} + +impl InnerVisit for GlobalVarInit { + fn inner_visit_with<'a>(&'a self, visitor: &mut impl Visitor<'a>) { + match self { + &GlobalVarInit::Direct(ct) => visitor.visit_const_use(ct), + GlobalVarInit::SpvAggregate { ty, leaves } => { + visitor.visit_type_use(*ty); + for &ct in leaves { + visitor.visit_const_use(ct); + } + } + GlobalVarInit::Composite { offset_to_value } => { + for &ct in offset_to_value.values() { + visitor.visit_const_use(ct); + } + } } } } From d5824f107a40d290a28deeae5f0fb287e3fe45d9 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Sun, 1 Oct 2023 17:12:52 +0300 Subject: [PATCH 19/22] qptr/lower: minimally support opaque handle `OpVariable`s. 
--- src/qptr/lower.rs | 48 +++++++++++++++++++++++++++++++++++++---------- src/spv/spec.rs | 4 ++++ 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/src/qptr/lower.rs b/src/qptr/lower.rs index 7c674da2..40952069 100644 --- a/src/qptr/lower.rs +++ b/src/qptr/lower.rs @@ -941,7 +941,7 @@ impl LowerFromSpvPtrInstsInFunc<'_> { // HACK(eddyb) this is for `aggregate_load_to_leaf_loads`, // which gets used later, to replace uses of one of the - // outputs ofthe original `OpLoad`, with uses of leaf loads. + // outputs of the original `OpLoad`, with uses of leaf loads. let mut leaf_loads = if spv_inst.opcode == wk.OpLoad { Some(SmallVec::with_capacity(leaf_forms_and_extra_inputs.len())) } else { @@ -1029,15 +1029,18 @@ impl LowerFromSpvPtrInstsInFunc<'_> { // FIXME(eddyb) this may need to automatically generate an // intermediary `QPtrOp::BufferData` when accessing buffers. - let mem_data_layout = match self.lowerer.layout_of(src_pointee_type)? { - TypeLayout::Concrete(mem) => mem, - _ => { - return Err(LowerError(Diag::bug([ - "`OpCopyMemory` of data with non-memory type: ".into(), - src_pointee_type.into(), - ]))); - } - }; + let mem_data_layout_or_opaque_handle_type = + match self.lowerer.layout_of(src_pointee_type)? { + TypeLayout::Concrete(mem) => Ok(mem), + // HACK(eddyb) Rust-GPU generates `OpCopyMemory`s of handles. + TypeLayout::Handle(shapes::Handle::Opaque(ty)) => Err(ty), + _ => { + return Err(LowerError(Diag::bug([ + "`OpCopyMemory` of data with non-memory type: ".into(), + src_pointee_type.into(), + ]))); + } + }; let (dst_ptr, dst_base_offset) = flatten_offsets(dst_ptr); let (src_ptr, src_base_offset) = flatten_offsets(src_ptr); @@ -1089,6 +1092,13 @@ impl LowerFromSpvPtrInstsInFunc<'_> { // be generating, because we don't know ahead of time whether we // even want to expand the `OpCopyMemory`, at all. 
let mut leaf_offsets_and_types = SmallVec::<[_; 8]>::new(); + let mem_data_layout = match mem_data_layout_or_opaque_handle_type { + Ok(mem_data_layout) => mem_data_layout, + Err(opaque_handle_type) => { + leaf_offsets_and_types.push((0, opaque_handle_type)); + return Some(leaf_offsets_and_types); + } + }; mem_data_layout .deeply_flatten_if( 0, @@ -1231,6 +1241,7 @@ impl LowerFromSpvPtrInstsInFunc<'_> { mut func_at_data_inst: FuncAtMut<'_, DataInst>, extra_error: Option, ) { + let wk = self.lowerer.wk; let cx = &self.lowerer.cx; let func_at_data_inst_frozen = func_at_data_inst.reborrow().freeze(); @@ -1282,6 +1293,23 @@ impl LowerFromSpvPtrInstsInFunc<'_> { continue; } + // HACK(eddyb) avoid otherwise-unsupported instructions ending up + // with invalid address spaces (such cases should be errors, + // but Rust-GPU still emits `Generic` everywhere, and having + // def-vs-use type mismatches, instead, would also cause issues). + let addr_space = match &data_inst_form_def.kind { + DataInstKind::SpvInst(spv_inst, _) => { + if spv_inst.opcode == wk.OpVariable { + AddrSpace::SpvStorageClass(wk.Function) + } else if spv_inst.opcode == wk.OpImageTexelPointer { + AddrSpace::SpvStorageClass(wk.Image) + } else { + addr_space + } + } + _ => addr_space, + }; + old_and_new_attrs.get_or_insert_with(get_old_attrs).attrs.insert( QPtrAttr::FromSpvPtrOutput { addr_space: OrdAssertEq(addr_space), diff --git a/src/spv/spec.rs b/src/spv/spec.rs index d5dc8f9b..40c59dec 100644 --- a/src/spv/spec.rs +++ b/src/spv/spec.rs @@ -153,6 +153,8 @@ def_well_known! { OpFunctionCall, + OpImageTexelPointer, + OpLoad, OpStore, OpCopyMemory, @@ -197,6 +199,8 @@ def_well_known! 
{ Input, Output, + Image, + IncomingRayPayloadKHR, IncomingCallableDataKHR, HitAttributeKHR, From 9ed8a3b6efce4682e83d8c40a95ec0a34bc2beba Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Sun, 1 Oct 2023 17:13:22 +0300 Subject: [PATCH 20/22] qptr/layout: allow `qptr`s to have a memory layout (and thus be loaded/stored). --- src/qptr/layout.rs | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/qptr/layout.rs b/src/qptr/layout.rs index 559c9bae..1c08d6dc 100644 --- a/src/qptr/layout.rs +++ b/src/qptr/layout.rs @@ -277,8 +277,6 @@ impl<'a> LayoutCache<'a> { let cx = &self.cx; let wk = self.wk; - let ty_def = &cx[ty]; - let scalar_with_size_and_align = |(size, align)| { TypeLayout::Concrete(Rc::new(MemTypeLayout { original_type: ty, @@ -289,6 +287,7 @@ impl<'a> LayoutCache<'a> { components: Components::Scalar, })) }; + let scalar = |width: u32| { assert!(width.is_power_of_two()); let size = width / 8; @@ -402,6 +401,7 @@ impl<'a> LayoutCache<'a> { // ugh this doesn't make any sense. maybe if the front-end specifies // offsets with "abstract types", it must configure `qptr::layout`? + let ty_def = &cx[ty]; let (spv_inst, type_and_const_inputs) = match &ty_def.kind { TypeKind::Scalar(scalar::Type::Bool) => { // FIXME(eddyb) make this properly abstract instead of only configurable. @@ -424,12 +424,15 @@ impl<'a> LayoutCache<'a> { ); } - // FIXME(eddyb) treat `QPtr`s as scalars. TypeKind::QPtr => { - return Err(LayoutError(Diag::bug( - ["`layout_of(qptr)` (already lowered?)".into()], - ))); + // FIXME(eddyb) make this properly abstract instead of only configurable. + // FIXME(eddyb) avoid logical vs physical conflicts here, maybe + // by adding more `LayoutConfig` fields, or only allowing `qptr`s + // to be kept in memory if `logical_ptr_size_align` agrees with + // the physical pointer size from the module addressing mode? 
+ return Ok(scalar_with_size_and_align(self.config.logical_ptr_size_align)); } + TypeKind::SpvInst { spv_inst, type_and_const_inputs, .. } => { (spv_inst, type_and_const_inputs) } From a8506cc97495b47199b19117f62d027ae679269c Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Sun, 1 Oct 2023 17:13:51 +0300 Subject: [PATCH 21/22] [WIP] qptr/simplify: add "partition" and "propagate" passes for function-local variables. --- examples/spv-lower-link-qptr-lift.rs | 6 + src/passes/qptr.rs | 49 +- src/qptr/analyze.rs | 4 +- src/qptr/mod.rs | 1 + src/qptr/simplify.rs | 1063 ++++++++++++++++++++++++++ 5 files changed, 1120 insertions(+), 3 deletions(-) create mode 100644 src/qptr/simplify.rs diff --git a/examples/spv-lower-link-qptr-lift.rs b/examples/spv-lower-link-qptr-lift.rs index d4b788d0..2328b6b1 100644 --- a/examples/spv-lower-link-qptr-lift.rs +++ b/examples/spv-lower-link-qptr-lift.rs @@ -92,6 +92,12 @@ fn main() -> std::io::Result<()> { eprintln!("qptr::lower_from_spv_ptrs"); after_pass("qptr::lower_from_spv_ptrs", &module)?; + eprint_duration(|| { + spirt::passes::qptr::partition_and_propagate(&mut module, layout_config) + }); + eprintln!("qptr::partition_and_propagate"); + after_pass("qptr::partition_and_propagate", &module)?; + eprint_duration(|| spirt::passes::qptr::analyze_uses(&mut module, layout_config)); eprintln!("qptr::analyze_uses"); after_pass("qptr::analyze_uses", &module)?; diff --git a/src/passes/qptr.rs b/src/passes/qptr.rs index c4f49a4b..96284365 100644 --- a/src/passes/qptr.rs +++ b/src/passes/qptr.rs @@ -1,7 +1,7 @@ //! [`QPtr`](crate::TypeKind::QPtr) transforms. 
use crate::visit::{InnerVisit, Visitor}; -use crate::{qptr, DataInstForm}; +use crate::{qptr, DataInstForm, DeclDef}; use crate::{AttrSet, Const, Context, Func, FxIndexSet, GlobalVar, Module, Type}; pub fn lower_from_spv_ptrs(module: &mut Module, layout_config: &qptr::LayoutConfig) { @@ -35,6 +35,53 @@ pub fn lower_from_spv_ptrs(module: &mut Module, layout_config: &qptr::LayoutConf } } +// FIXME(eddyb) split this into separate passes, but the looping complicates things. +pub fn partition_and_propagate(module: &mut Module, layout_config: &qptr::LayoutConfig) { + let cx = &module.cx(); + + let (_seen_global_vars, seen_funcs) = { + // FIXME(eddyb) reuse this collection work in some kind of "pass manager". + let mut collector = ReachableUseCollector { + cx, + module, + + seen_types: FxIndexSet::default(), + seen_consts: FxIndexSet::default(), + seen_data_inst_forms: FxIndexSet::default(), + seen_global_vars: FxIndexSet::default(), + seen_funcs: FxIndexSet::default(), + }; + for (export_key, &exportee) in &module.exports { + export_key.inner_visit_with(&mut collector); + exportee.inner_visit_with(&mut collector); + } + (collector.seen_global_vars, collector.seen_funcs) + }; + + for func in seen_funcs { + if let DeclDef::Present(func_def_body) = &mut module.funcs[func].def { + // FIXME(eddyb) reuse `LayoutCache` and whatnot, between functions, + // or at least iterations of this loop. 
+ loop { + qptr::simplify::partition_local_vars_in_func( + cx.clone(), + layout_config, + func_def_body, + ); + + let report = qptr::simplify::propagate_contents_of_local_vars_in_func( + cx.clone(), + layout_config, + func_def_body, + ); + if !report.any_qptrs_propagated { + break; + } + } + } + } +} + pub fn analyze_uses(module: &mut Module, layout_config: &qptr::LayoutConfig) { qptr::analyze::InferUsage::new(module.cx(), layout_config).infer_usage_in_module(module); } diff --git a/src/qptr/analyze.rs b/src/qptr/analyze.rs index d95517a3..140620fa 100644 --- a/src/qptr/analyze.rs +++ b/src/qptr/analyze.rs @@ -1,9 +1,9 @@ //! [`QPtr`](crate::TypeKind::QPtr) usage analysis (for legalizing/lifting). // HACK(eddyb) sharing layout code with other modules. -use super::{layout::*, QPtrMemUsageKind}; +use super::layout::*; -use super::{shapes, QPtrAttr, QPtrMemUsage, QPtrOp, QPtrUsage}; +use super::{shapes, QPtrAttr, QPtrMemUsage, QPtrMemUsageKind, QPtrOp, QPtrUsage}; use crate::func_at::FuncAt; use crate::visit::{InnerVisit, Visitor}; use crate::{ diff --git a/src/qptr/mod.rs b/src/qptr/mod.rs index 513bcb47..4f479294 100644 --- a/src/qptr/mod.rs +++ b/src/qptr/mod.rs @@ -18,6 +18,7 @@ mod layout; pub mod lift; pub mod lower; pub mod shapes; +pub mod simplify; pub use layout::LayoutConfig; diff --git a/src/qptr/simplify.rs b/src/qptr/simplify.rs new file mode 100644 index 00000000..7934d7e3 --- /dev/null +++ b/src/qptr/simplify.rs @@ -0,0 +1,1063 @@ +//! [`QPtr`](crate::TypeKind::QPtr) simplification passes. + +// HACK(eddyb) sharing layout code with other modules. 
+use super::layout::*; + +use super::{shapes, QPtrOp}; +use crate::func_at::{FuncAt, FuncAtMut}; +use crate::visit::{InnerVisit, Visitor}; +use crate::{ + vector, AttrSet, Const, ConstDef, ConstKind, Context, ControlNodeOutputDecl, ControlRegion, + ControlRegionDef, ControlRegionInputDecl, DataInst, DataInstDef, DataInstForm, DataInstKind, + Func, FuncDefBody, FxIndexMap, FxIndexSet, GlobalVar, Type, TypeKind, +}; +use crate::{ControlNode, Value}; +use crate::{ControlNodeKind, DataInstFormDef}; +use smallvec::SmallVec; +use std::collections::BTreeMap; +use std::num::NonZeroU32; +use std::ops::{Bound, Range}; +use std::rc::Rc; +use std::{mem, slice}; + +/// Split all function-local variables in `func_def_body` into independent ones. +// +// FIXME(eddyb) reduce the cost of creating all the per-partition local variables +// by feeding partitions to `propagate_contents_of_local_vars_in_func` directly. +pub fn partition_local_vars_in_func( + cx: Rc, + config: &LayoutConfig, + func_def_body: &mut FuncDefBody, +) { + let vars = { + let mut collector = CollectLocalVarPartitions { + cx: cx.clone(), + layout_cache: LayoutCache::new(cx.clone(), config), + vars: FxIndexMap::default(), + }; + func_def_body.inner_visit_with(&mut collector); + collector.vars + }; + + let qptr_type = cx.intern(TypeKind::QPtr); + + // Create new variables for all partitions, and replace their respective uses. + for (original_var_inst, var) in vars { + let original_var_qptr = Value::DataInstOutput { inst: original_var_inst, output_idx: 0 }; + + // Also shrink the original variable, if necessary. 
+ if var.zero_offset_partition_size < var.original_layout.size { + func_def_body.at_mut(original_var_inst).def().form = cx.intern(DataInstFormDef { + kind: QPtrOp::FuncLocalVar(shapes::MemLayout { + size: var.zero_offset_partition_size, + ..var.original_layout + }) + .into(), + output_types: [qptr_type].into_iter().collect(), + }); + } + + for (partition_offset, partition) in var.non_zero_offset_to_partition { + let align_for_offset = 1 << partition_offset.trailing_zeros(); + + let partition_var_inst = func_def_body.data_insts.define( + &cx, + DataInstDef { + // FIXME(eddyb) preserve at least debuginfo attrs. + attrs: Default::default(), + form: cx.intern(DataInstFormDef { + kind: QPtrOp::FuncLocalVar(shapes::MemLayout { + align: var.original_layout.align.min(align_for_offset), + legacy_align: var.original_layout.legacy_align.min(align_for_offset), + size: partition.size, + }) + .into(), + output_types: [qptr_type].into_iter().collect(), + }), + inputs: Default::default(), + } + .into(), + ); + + match &mut func_def_body.control_nodes[var.parent_block].kind { + ControlNodeKind::Block { insts } => { + // FIXME(eddyb) this could use an `insert_after`, to avoid + // having all the partitions end up before the original. + insts.insert_before( + partition_var_inst, + original_var_inst, + &mut func_def_body.data_insts, + ); + } + _ => unreachable!(), + } + + let partition_var_qptr = + Value::DataInstOutput { inst: partition_var_inst, output_idx: 0 }; + + // FIXME(eddyb) when `QPtrOp::Offset` ends up with a `0` offset, + // some further simplifications are possible, but it's not that + // relevant for now, as we're mainly interested in loads/stores.
+ for use_inst in partition.uses { + let data_inst_def = func_def_body.at_mut(use_inst).def(); + + assert!( + mem::replace(&mut data_inst_def.inputs[0], partition_var_qptr) + == original_var_qptr + ); + + let mut data_inst_form_def = cx[data_inst_def.form].clone(); + match &mut data_inst_form_def.kind { + DataInstKind::QPtr( + QPtrOp::Offset(offset) | QPtrOp::Load { offset } | QPtrOp::Store { offset }, + ) => { + *offset = + offset.checked_sub(partition_offset.get().try_into().unwrap()).unwrap(); + } + _ => unreachable!(), + } + data_inst_def.form = cx.intern(data_inst_form_def); + } + } + } +} + +struct CollectLocalVarPartitions<'a> { + cx: Rc, + layout_cache: LayoutCache<'a>, + vars: FxIndexMap, +} + +struct LocalVarPartitions { + parent_block: ControlNode, + original_layout: shapes::MemLayout, + // HACK(eddyb) offset `0` reuses the original variable and is tracked separately, + // to reduce the cost for both the collection, and to make replacement a noop. + zero_offset_partition_size: u32, + non_zero_offset_to_partition: BTreeMap, +} + +#[derive(Default)] +struct Partition { + size: u32, + + /// All the `DataInst`s that have a `QPtr` input and an immediate offset + /// (`QPtrOp::{Offset,Load,Store}`), which are updated after partitioning. + uses: SmallVec<[DataInst; 4]>, +} + +impl LocalVarPartitions { + /// Remove all partitions and prevent any further ones from being added + /// (typically only needed for variables used in unknown ways). + fn forfeit_partitioning(&mut self) { + self.zero_offset_partition_size = self.original_layout.size; + mem::take(&mut self.non_zero_offset_to_partition); + } + + /// Record a new partition range, and the `DataInst` it originated from, + /// merging ranges and uses with existing ones, in case of overlaps. + fn add_use(&mut self, range: Range, use_inst: DataInst) { + // FIXME(eddyb) the logic below is not amenable to ZSTs. 
+ if range.is_empty() { + return self.forfeit_partitioning(); + } + + // The partition starting at `0` is special and does not track `uses`. + if range.start == 0 || range.start < self.zero_offset_partition_size { + self.zero_offset_partition_size = self.zero_offset_partition_size.max(range.end); + + // Absorb overlaps without keeping track of their `uses`. + while let Some(entry) = self.non_zero_offset_to_partition.first_entry() { + let partition_offset = entry.key().get(); + if range.end <= partition_offset { + break; + } + let partition = entry.remove(); + self.zero_offset_partition_size = + partition_offset.checked_add(partition.size).unwrap(); + } + return; + } + + let range = NonZeroU32::new(range.start).unwrap()..NonZeroU32::new(range.end).unwrap(); + let mut rev_overlapping_entries = self + .non_zero_offset_to_partition + .range_mut((Bound::Unbounded, Bound::Excluded(range.end))) + .rev() + .take_while(|(&partition_offset, partition)| { + partition_offset.checked_add(partition.size).unwrap() > range.start + }); + + // Fast path: `range` begins in an existing partition, and either already + // ends within it, or at least ends before the next existing partition + // (the second condition is guaranteed by this being the *last* overlap). + let mut last_overlapping_entry = rev_overlapping_entries.next(); + if let Some((&partition_offset, partition)) = &mut last_overlapping_entry { + if partition_offset <= range.start { + partition.size = partition.size.max(range.end.get() - partition_offset.get()); + partition.uses.push(use_inst); + return; + } + } + + let rev_overlapping_entries = + last_overlapping_entry.into_iter().chain(rev_overlapping_entries); + + // FIXME(eddyb) this is a bit inefficient but we don't have + // cursors, so we have to buffer the `BTreeMap` keys here. 
+ let rev_overlapping_offsets: SmallVec<[_; 4]> = + rev_overlapping_entries.map(|(&offset, _)| offset).collect(); + + let merged_entry = rev_overlapping_offsets + .into_iter() + .rev() + .map(|offset| (offset, self.non_zero_offset_to_partition.remove(&offset).unwrap())) + .chain([( + range.start, + Partition { + size: range.end.get() - range.start.get(), + uses: [use_inst].into_iter().collect(), + }, + )]) + .reduce(|(a_start, a), (b_start, b)| { + let (a_end, b_end) = + (a_start.checked_add(a.size).unwrap(), b_start.checked_add(b.size).unwrap()); + let start = a_start.min(b_start); + let mut uses = a.uses; + uses.extend(b.uses); + (start, Partition { size: a_end.max(b_end).get() - start.get(), uses }) + }) + .unwrap(); + self.non_zero_offset_to_partition.extend([merged_entry]); + } +} + +impl Visitor<'_> for CollectLocalVarPartitions<'_> { + // FIXME(eddyb) this is excessive, maybe different kinds of + // visitors should exist for module-level and func-level? + fn visit_attr_set_use(&mut self, _: AttrSet) {} + fn visit_type_use(&mut self, _: Type) {} + fn visit_const_use(&mut self, _: Const) {} + fn visit_data_inst_form_use(&mut self, _: DataInstForm) {} + fn visit_global_var_use(&mut self, _: GlobalVar) {} + fn visit_func_use(&mut self, _: Func) {} + + // NOTE(eddyb) uses of variables that end up here disable partitioning of + // that variable, as they're not one of the special cases which are allowed. + fn visit_value_use(&mut self, &v: &Value) { + if let Value::DataInstOutput { inst, output_idx } = v { + if let Some(var) = self.vars.get_mut(&inst) { + assert_eq!(output_idx, 0); + var.forfeit_partitioning(); + } + } + } + + // FIXME(eddyb) we can't use `visit_data_inst_def` because we need either + // the resulting `DataInst`, or access to `FuncAt::type_of`. 
+ fn visit_control_node_def(&mut self, func_at_control_node: FuncAt<'_, ControlNode>) { + if let ControlNodeKind::Block { insts } = func_at_control_node.def().kind { + for func_at_inst in func_at_control_node.at(insts) { + let data_inst_def = func_at_inst.def(); + let data_inst_form_def = &self.cx[data_inst_def.form]; + if let DataInstKind::QPtr(op) = &data_inst_form_def.kind { + let first_input_qptr_with_offset_and_access_type = match *op { + QPtrOp::FuncLocalVar(layout) => { + // FIXME(eddyb) support optional initializers. + if data_inst_def.inputs.is_empty() { + self.vars.insert( + func_at_inst.position, + LocalVarPartitions { + parent_block: func_at_control_node.position, + original_layout: layout, + zero_offset_partition_size: 0, + non_zero_offset_to_partition: BTreeMap::new(), + }, + ); + } + + None + } + + // FIXME(eddyb) support more uses of `qptr`s. + QPtrOp::Offset(offset) => { + // FIXME(eddyb) we could have a narrower range here, + // if it was recorded during `qptr::lower`. + Some((offset, None)) + } + QPtrOp::Load { offset } => { + Some((offset, Some(data_inst_form_def.output_types[0]))) + } + QPtrOp::Store { offset } => Some(( + offset, + Some(func_at_inst.at(data_inst_def.inputs[1]).type_of(&self.cx)), + )), + + _ => None, + }; + let first_input_var_with_offset_range = + first_input_qptr_with_offset_and_access_type.and_then( + |(offset, access_type)| { + if let Value::DataInstOutput { inst, output_idx } = + data_inst_def.inputs[0] + { + let var = self.vars.get_mut(&inst)?; + assert_eq!(output_idx, 0); + + let start = u32::try_from(offset).ok()?; + + let end = match access_type { + Some(ty) => match self.layout_cache.layout_of(ty).ok()? { + TypeLayout::Concrete(layout) + if layout.mem_layout.dyn_unit_stride.is_none() => + { + start.checked_add( + layout.mem_layout.fixed_base.size, + )?
+ } + _ => return None, + }, + None => var.original_layout.size, + }; + + Some((var, start..end)) + } else { + None + } + }, + ); + if let Some((var, offset_range)) = first_input_var_with_offset_range { + var.add_use(offset_range, func_at_inst.position); + + // Only visit the *other* inputs, not the `qptr` one. + for v in &data_inst_def.inputs[1..] { + self.visit_value_use(v); + } + + continue; + } + } + data_inst_def.inner_visit_with(self); + } + } else { + func_at_control_node.inner_visit_with(self); + } + } +} + +#[must_use] +#[derive(Default)] +pub struct PropagateLocalVarContentsReport { + /// Whether at least one of the function-local variables that had its contents + /// propagated, held a `qptr`, which may now allow further simplifications. + pub any_qptrs_propagated: bool, +} + +/// Propagate (from stores to loads) contents of `func_def_body`'s local variables. +pub fn propagate_contents_of_local_vars_in_func( + cx: Rc, + config: &LayoutConfig, + func_def_body: &mut FuncDefBody, +) -> PropagateLocalVarContentsReport { + let mut report = PropagateLocalVarContentsReport::default(); + + // Avoid having to support unstructured control-flow. + if func_def_body.unstructured_cfg.is_some() { + return report; + } + + let (vars, propagated_loads) = { + let mut propagator = PropagateLocalVarContents { + cx: &cx, + layout_cache: LayoutCache::new(cx.clone(), config), + vars: FxIndexMap::default(), + mutation_log: vec![], + propagated_loads: FxIndexMap::default(), + }; + propagator.propagate_through_control_region(func_def_body.at_mut_body()); + (propagator.vars, propagator.propagated_loads) + }; + + // FIXME(eddyb) this is not the most efficient way to compute this, but it + // should be straight-forwardly correct to do it here. 
+ report.any_qptrs_propagated = vars + .values() + .filter_map(|var| var.as_ref().ok()?.ty) + .any(|ty| matches!(cx[ty].kind, TypeKind::QPtr)); + + let insts_to_remove = propagated_loads + .into_iter() + .map(|(original_inst, (_, parent_block))| (original_inst, parent_block)) + .chain(vars.into_iter().flat_map(|(original_var_inst, var_contents)| { + var_contents.ok().into_iter().flat_map(move |var_contents| { + [(original_var_inst, var_contents.parent_block)] + .into_iter() + .chain(var_contents.stores_with_parent_block) + }) + })); + for (inst, parent_block) in insts_to_remove { + match &mut func_def_body.control_nodes[parent_block].kind { + ControlNodeKind::Block { insts } => { + insts.remove(inst, &mut func_def_body.data_insts); + } + _ => unreachable!(), + } + } + + report +} + +struct PropagateLocalVarContents<'a> { + cx: &'a Context, + layout_cache: LayoutCache<'a>, + + vars: FxIndexMap>, + + // HACK(eddyb) this allows a flat representation, and handling `Select` + // control nodes at a cost proportional only to the number of variables + // modified in any of the child regions (not the total number of variables). + mutation_log: Vec, + + /// `QPtrOp::Load` instructions with known output `Value`s, and also tracking + /// their parent `Block` control node for later removal. + // + // FIXME(eddyb) it should be possible to remove the loads as they are seen. + propagated_loads: FxIndexMap, +} + +/// Error type for when a function-local variable's `LocalVarContents` cannot be +/// tracked, either because a pointer into it escapes, or there is some other +/// issue preventing tracking (e.g. layout error, type mismatch, etc.). +struct UnknowableLocalVar; + +struct LocalVarContents { + parent_block: ControlNode, + size: u32, + + /// Deduced type (of `value`, but may be present even if `value` is missing), + /// which cannot change once set (instead, `UnknowableLocalVar` is produced). 
+ ty: Option, + + value: Option, + + /// `QPtrOp::Store` instructions to remove, if the whole variable is removed, + /// and their parent `Block` control node. + stores_with_parent_block: SmallVec<[(DataInst, ControlNode); 4]>, +} + +struct LocalVarMutation { + /// Index of the variable in the `vars` field of `PropagateLocalVarContents`. + var_idx: usize, + + /// Previous value of the `value` field of `LocalVarContents`. + prev_value: Option, +} + +struct LocalVarAccess<'a> { + /// Index of the variable in the `vars` field of `PropagateLocalVarContents`. + var_idx: usize, + + var: &'a mut LocalVarContents, + + /// If the local variable is an `OpTypeVector`, and this access is for one + /// of its scalar elements, this will contain that element's index. + vector_elem_idx: Option, +} + +impl PropagateLocalVarContents<'_> { + /// Validate an access into `var_qptr`, at `offset`, with type `access_type`, + /// returning `Some` if, and only if, the access does not conflict with any + /// previous ones, type-wise (with accesses smaller than the whole variable + /// being inferred as vector element accesses if a valid vector type fits). + /// + /// When `Some(access)` is returned, `access.var.ty` is guaranteed to be + /// `Some`, and the type of `access.var.value` (if the latter is present). + fn lookup_var_for_access( + &mut self, + var_qptr: Value, + offset: i32, + access_type: Type, + ) -> Option> { + // HACK(eddyb) we steal the `LocalVarContents` to make the logic below + // easier to write: if *anything* goes wrong, `Err(UnknowableLocalVar)` + // will be left behind, and `Ok(var)` will be restored if and only if + // everything about this access is valid (and `Some` will be returned). + let (var_idx, mut var) = match var_qptr { + Value::DataInstOutput { inst, output_idx } => { + let (var_idx, _, var) = self.vars.get_full_mut(&inst)?; + assert_eq!(output_idx, 0); + (var_idx, mem::replace(var, Err(UnknowableLocalVar)).ok()?) 
+ } + _ => return None, + }; + + let offset = u32::try_from(offset).ok()?; + + let layout = match self.layout_cache.layout_of(access_type).ok()? { + TypeLayout::Concrete(layout) if layout.mem_layout.dyn_unit_stride.is_none() => layout, + _ => return None, + }; + let access_size = layout.mem_layout.fixed_base.size; + + let (inferred_var_type, vector_elem_idx) = if offset == 0 && access_size == var.size { + (layout.original_type, None) + } else { + // HACK(eddyb) we only support vector types here, as + // they're the most common cause of partial loads/stores. + let inferred_vector_len = var.size / access_size; + let elem_idx = offset / access_size; + + let scalar_access_type = access_type.as_scalar(self.cx)?; + let legal_vector = var.size % access_size == 0 + && offset % access_size == 0 + && (2..=4).contains(&inferred_vector_len); + if !legal_vector { + return None; + } + ( + self.cx.intern(vector::Type { + elem: scalar_access_type, + elem_count: u8::try_from(inferred_vector_len).ok()?.try_into().ok()?, + }), + Some(u8::try_from(elem_idx).unwrap()), + ) + }; + + if var.ty.is_some_and(|ty| ty != inferred_var_type) { + return None; + } + var.ty = Some(inferred_var_type); + + self.vars[var_idx] = Ok(var); + let var = self.vars[var_idx].as_mut().ok().unwrap(); + + // FIXME(eddyb) should the returned value not even contain a reference + // into `self.vars`, given that it's entirely relying on indexing? + Some(LocalVarAccess { var_idx, var, vector_elem_idx }) + } + + /// Apply active rewrites (i.e. `propagated_loads`) to all `values`. 
+ fn propagate_into_values(&mut self, values: &mut [Value]) { + for v in values { + if let Value::DataInstOutput { inst, output_idx } = *v { + if let Some(&(replacement_value, _)) = self.propagated_loads.get(&inst) { + assert_eq!(output_idx, 0); + *v = replacement_value; + } + } + } + } + + /// Record `values` as used - this is expected to be called only after + /// `propagate_into_values` was applied, and not to include `qptr`s which + /// were part of propagated loads/stores, as this'd mark them as unknowable. + fn track_value_uses(&mut self, values: &[Value]) { + for &v in values { + if let Value::DataInstOutput { inst, output_idx } = v { + if let Some(var) = self.vars.get_mut(&inst) { + assert_eq!(output_idx, 0); + *var = Err(UnknowableLocalVar); + } + } + } + } + + fn propagate_through_data_inst( + &mut self, + mut func_at_inst: FuncAtMut<'_, DataInst>, + parent_block: ControlNode, + ) { + let cx = self.cx; + + let const_undef = |ty| { + Value::Const(cx.intern(ConstDef { + attrs: AttrSet::default(), + ty, + kind: ConstKind::Undef, + })) + }; + + let data_inst = func_at_inst.position; + + let DataInstDef { attrs: _, form, inputs } = func_at_inst.reborrow().def(); + + let DataInstFormDef { kind, output_types } = &cx[*form]; + + // FIXME(eddyb) it may be helpful to fold uses after propagation, + // (e.g. `qptr.offset` into `qptr.{load,store}`), to allow propagation + // of variables who had their pointers stored in other variables - note + // that multiple propagation passes would *still* be needed, because the + // original store of a pointer to a variable will make it unknowable. 
+ self.propagate_into_values(inputs); + + match *kind { + DataInstKind::QPtr(QPtrOp::FuncLocalVar(layout)) => { + assert!(inputs.len() <= 1); + let init_value = inputs.first().copied(); + + self.vars.insert( + data_inst, + Ok(LocalVarContents { + parent_block, + size: layout.size, + ty: init_value.map(|v| func_at_inst.reborrow().freeze().at(v).type_of(cx)), + value: init_value, + stores_with_parent_block: Default::default(), + }), + ); + } + + DataInstKind::QPtr(QPtrOp::Load { offset }) => { + assert_eq!(inputs.len(), 1); + let src_ptr = inputs[0]; + + if let Some(access) = self.lookup_var_for_access(src_ptr, offset, output_types[0]) { + let var_ty = access.var.ty.unwrap(); + + // HACK(eddyb) cache the `OpUndef` constant in-place. + let var_value = *access.var.value.get_or_insert_with(|| const_undef(var_ty)); + + match access.vector_elem_idx { + None => { + self.propagated_loads.insert(data_inst, (var_value, parent_block)); + // FIXME(eddyb) maybe remove the instruction here and now? + } + + // Element loads from vector variables don't need to + // have their uses replaced, but rather become extracts. 
+ Some(elem_idx) => { + *form = cx.intern(DataInstFormDef { + kind: vector::Op::from(vector::WholeOp::Extract { elem_idx }) + .into(), + output_types: output_types.clone(), + }); + *inputs = [var_value].into_iter().collect(); + } + } + + return; + } + } + + DataInstKind::QPtr(QPtrOp::Store { offset }) => { + assert_eq!(inputs.len(), 2); + let dst_ptr = inputs[0]; + let stored_value = inputs[1]; + + if let Some(access) = self.lookup_var_for_access( + dst_ptr, + offset, + func_at_inst.reborrow().freeze().at(stored_value).type_of(cx), + ) { + let var_ty = access.var.ty.unwrap(); + + let new_var_value = match access.vector_elem_idx { + None => stored_value, + + // Element stores into vector variables become inserts, + // but because we don't know yet if the store will be + // removed (as the variable can still escape later, or + // change type, etc.), the insert needs to be separate. + Some(elem_idx) => { + // HACK(eddyb) cache the `OpUndef` constant in-place + // (this may seem unnecessary, but the `mutation_log` + // will record the `OpUndef` as the `prev_value`). + let var_value = + *access.var.value.get_or_insert_with(|| const_undef(var_ty)); + + let vector_insert_data_inst = + func_at_inst.reborrow().data_insts.define( + cx, + DataInstDef { + // FIXME(eddyb) preserve at least debuginfo attrs. + attrs: Default::default(), + form: cx.intern(DataInstFormDef { + kind: vector::Op::from(vector::WholeOp::Insert { + elem_idx, + }) + .into(), + output_types: [var_ty].into_iter().collect(), + }), + inputs: [stored_value, var_value].into_iter().collect(), + } + .into(), + ); + + // HACK(eddyb) can't really use helpers like `FuncAtMut::def`, + // due to the need to borrow `control_nodes` and `data_insts` + // at the same time - perhaps some kind of `FuncAtMut` position + // types for "where a list is in a parent entity" could be used + // to make this more ergonomic, although the potential need for + // an actual list entity of its own, should be considered. 
+ let func = func_at_inst.reborrow().at(()); + match &mut func.control_nodes[parent_block].kind { + ControlNodeKind::Block { insts } => { + insts.insert_before( + vector_insert_data_inst, + data_inst, + func.data_insts, + ); + } + _ => unreachable!(), + } + + Value::DataInstOutput { inst: vector_insert_data_inst, output_idx: 0 } + } + }; + + let prev_value = access.var.value.replace(new_var_value); + access.var.stores_with_parent_block.push((data_inst, parent_block)); + let var_idx = access.var_idx; + self.mutation_log.push(LocalVarMutation { var_idx, prev_value }); + + // Only visit the value input, not the destination pointer. + self.track_value_uses(&[stored_value]); + + return; + } + } + + _ => {} + } + + self.track_value_uses(&func_at_inst.def().inputs); + } + + fn propagate_through_control_region( + &mut self, + mut func_at_region: FuncAtMut<'_, ControlRegion>, + ) { + let mut children = func_at_region.reborrow().at_children().into_iter(); + while let Some(func_at_control_node) = children.next() { + self.propagate_through_control_node(func_at_control_node); + } + + let ControlRegionDef { inputs: _, children: _, outputs } = func_at_region.def(); + self.propagate_into_values(outputs); + self.track_value_uses(outputs); + } + + fn propagate_through_control_node(&mut self, func_at_control_node: FuncAtMut<'_, ControlNode>) { + let const_undef = |ty| { + Value::Const(self.cx.intern(ConstDef { + attrs: AttrSet::default(), + ty, + kind: ConstKind::Undef, + })) + }; + + let control_node = func_at_control_node.position; + + // FIXME(eddyb) is this a good convention? 
+ let mut func = func_at_control_node.at(()); + + match &mut func.reborrow().at(control_node).def().kind { + &mut ControlNodeKind::Block { insts } => { + let mut func_at_inst_iter = func.at(insts).into_iter(); + while let Some(func_at_inst) = func_at_inst_iter.next() { + self.propagate_through_data_inst(func_at_inst, control_node); + } + } + ControlNodeKind::Select { kind: _, scrutinee, cases } => { + self.propagate_into_values(slice::from_mut(scrutinee)); + self.track_value_uses(&[*scrutinee]); + + let num_cases = cases.len(); + + // FIXME(eddyb) represent the list of child regions without having them + // in a `Vec` (or `SmallVec`), which requires workarounds like this. + let get_case = |func: FuncAtMut<'_, ()>, i| match &func.at(control_node).def().kind + { + ControlNodeKind::Select { cases, .. } => cases[i], + _ => unreachable!(), + }; + + // HACK(eddyb) degenerate `Select`s do not actually need merges. + if num_cases <= 1 { + if num_cases == 1 { + let case = get_case(func.reborrow(), 0); + self.propagate_through_control_region(func.at(case)); + } + return; + } + + // HACK(eddyb) this is how we can both roll back changes to + // variables' `value`s, and know which variables were changed + // in the first place (to merge their changed values together). + let mutation_log_start = self.mutation_log.len(); + + let mut var_idx_to_per_case_values = + FxIndexMap::>::default(); + for case_idx in 0..num_cases { + let case = get_case(func.reborrow(), case_idx); + self.propagate_through_control_region(func.reborrow().at(case)); + + // NOTE(eddyb) we traverse the mutation log forwards, as we + // already have a way to determine whether we've seen any + // mutations for each variable, and only the oldest mutation + // is needed to roll back the variable to its original state. + for mutation in self.mutation_log.drain(mutation_log_start..) 
{ + let original_var_value = mutation.prev_value; + if let Ok(var) = &mut self.vars[mutation.var_idx] { + let per_case_var_values = var_idx_to_per_case_values + .entry(mutation.var_idx) + .or_insert_with(|| { + let mut per_case_var_values = + SmallVec::with_capacity(num_cases); + + // This case may be the first to mutate this + // variable - thankfully we know the original + // value (which will be common across all cases). + per_case_var_values + .extend((0..case_idx).map(|_| original_var_value)); + + per_case_var_values + }); + + if per_case_var_values.len() <= case_idx { + let new_var_value = + mem::replace(&mut var.value, original_var_value); + per_case_var_values.push(new_var_value); + } + assert_eq!(per_case_var_values.len() - 1, case_idx); + } + } + + // Some variables may only have been mutated in previous cases. + for (&var_idx, per_case_var_values) in &mut var_idx_to_per_case_values { + if per_case_var_values.len() <= case_idx { + if let Ok(var) = &self.vars[var_idx] { + per_case_var_values.push(var.value); + assert_eq!(per_case_var_values.len() - 1, case_idx); + } + } + } + } + + // Variables mutated in at least one case can now be merged, + // by creating `Select` outputs for all of them. + for (var_idx, per_case_var_values) in var_idx_to_per_case_values { + if let Ok(var) = &mut self.vars[var_idx] { + assert_eq!(per_case_var_values.len(), num_cases); + + // HACK(eddyb) do not create outputs if all cases agree. 
+ let v0 = per_case_var_values[0]; + if per_case_var_values[1..].iter().all(|&v| v == v0) { + let prev_value = mem::replace(&mut var.value, v0); + self.mutation_log.push(LocalVarMutation { var_idx, prev_value }); + continue; + } + + let var_ty = var.ty.unwrap(); + + let select_output_decls = + &mut func.reborrow().at(control_node).def().outputs; + let output_idx = u32::try_from(select_output_decls.len()).unwrap(); + select_output_decls + .push(ControlNodeOutputDecl { attrs: Default::default(), ty: var_ty }); + + // FIXME(eddyb) avoid random access, perhaps by handling + // variables per-case, instead of cases per-variable. + for (case_idx, per_case_var_value) in + (0..num_cases).zip(per_case_var_values) + { + let case = get_case(func.reborrow(), case_idx); + let per_case_outputs = &mut func.reborrow().at(case).def().outputs; + assert_eq!(per_case_outputs.len(), output_idx as usize); + per_case_outputs + .push(per_case_var_value.unwrap_or_else(|| const_undef(var_ty))); + } + + let prev_value = var + .value + .replace(Value::ControlNodeOutput { control_node, output_idx }); + self.mutation_log.push(LocalVarMutation { var_idx, prev_value }); + } + } + } + ControlNodeKind::Loop { initial_inputs, body, repeat_condition: _ } => { + self.propagate_into_values(initial_inputs); + self.track_value_uses(initial_inputs); + + let body = *body; + + // HACK(eddyb) as the body of the loop may execute multiple times, + // the initial states of variables have to account for potential + // mutations in previous iterations, which we detect with this + // separate visitor, then plumb through the region inputs/outputs. 
+ let mut mutated_var_indices = { + let mut mutation_finder = FindMutatedLocalVars { + propagator: self, + mutated_var_indices: FxIndexSet::default(), + }; + mutation_finder.visit_control_region_def(func.reborrow().freeze().at(body)); + mutation_finder.mutated_var_indices + }; + mutated_var_indices.retain(|&var_idx| match &mut self.vars[var_idx] { + Ok(var) => { + let var_ty = var.ty.unwrap(); + + let body_input_decls = &mut func.reborrow().at(body).def().inputs; + let input_idx = u32::try_from(body_input_decls.len()).unwrap(); + body_input_decls + .push(ControlRegionInputDecl { attrs: Default::default(), ty: var_ty }); + + let prev_value = var + .value + .replace(Value::ControlRegionInput { region: body, input_idx }); + + let initial_inputs = match &mut func.reborrow().at(control_node).def().kind + { + ControlNodeKind::Loop { initial_inputs, .. } => initial_inputs, + _ => unreachable!(), + }; + assert_eq!(initial_inputs.len(), input_idx as usize); + initial_inputs.push(prev_value.unwrap_or_else(|| const_undef(var_ty))); + + // NOTE(eddyb) can't avoid this, because the original + // values of mutated variables would otherwise be lost. + self.mutation_log.push(LocalVarMutation { var_idx, prev_value }); + + true + } + Err(_) => false, + }); + + let body_mutation_log_start = self.mutation_log.len(); + self.propagate_through_control_region(func.reborrow().at(body)); + + // Record the updated values of variables, for future iterations. + let body_outputs = &mut func.reborrow().at(body).def().outputs; + body_outputs.extend(mutated_var_indices.iter().map(|&var_idx| { + // HACK(eddyb) we require `FindMutatedLocalVars` to perfectly + // model all the situations in which we may reach an error + // (i.e. `UnknowableLocalVar`), and in which variables get + // mutated, because we may have *already* replaced loads in + // `body` to refer to values stored *in previous iterations*, + // so we need those values to actually be always usable. 
+ self.vars[var_idx].as_ref().ok().unwrap().value.unwrap() + })); + + // HACK(eddyb) because we already recorded all the mutations + // based on `mutated_var_indices` alone, we can discard all the + // redundant log entries (this also doubles as a sanity check). + // FIXME(eddyb) this requires two passes to avoid new allocations + // for deduplicating the set of mutated variables - perhaps it may + // be possible for `mutation_log` to always deduplicate itself + // "since the most recent snapshot" or something? + for mutation in &self.mutation_log[body_mutation_log_start..] { + assert!(mutated_var_indices.contains(&mutation.var_idx)); + } + for mutation in self.mutation_log.drain(body_mutation_log_start..) { + mutated_var_indices.swap_remove(&mutation.var_idx); + } + assert_eq!(mutated_var_indices.len(), 0); + + let repeat_condition = match &mut func.at(control_node).def().kind { + ControlNodeKind::Loop { repeat_condition, .. } => repeat_condition, + _ => unreachable!(), + }; + self.propagate_into_values(slice::from_mut(repeat_condition)); + self.track_value_uses(&[*repeat_condition]); + } + } + } +} + +/// Helper `Visitor` used when propagating local variables across a `Loop`, to +/// determine *ahead of time* which variables require `ControlRegion` inputs. +struct FindMutatedLocalVars<'a, 'b> { + propagator: &'a mut PropagateLocalVarContents<'b>, + + /// Indices of mutated variables, in the `propagator.vars` `IndexMap`. + // FIXME(eddyb) this could probably be a compact bitset. + // FIXME(eddyb) a more accurate check would also consider whether values from + // previous iterations (or before the loop) are needed, not just mutations. + mutated_var_indices: FxIndexSet, +} + +impl Visitor<'_> for FindMutatedLocalVars<'_, '_> { + // FIXME(eddyb) this is excessive, maybe different kinds of + // visitors should exist for module-level and func-level? 
+ fn visit_attr_set_use(&mut self, _: AttrSet) {} + fn visit_type_use(&mut self, _: Type) {} + fn visit_const_use(&mut self, _: Const) {} + fn visit_data_inst_form_use(&mut self, _: DataInstForm) {} + fn visit_global_var_use(&mut self, _: GlobalVar) {} + fn visit_func_use(&mut self, _: Func) {} + + // NOTE(eddyb) uses of variables that end up here disable tracking of + // that variable's contents (see also `UnknowableLocalVar`). + fn visit_value_use(&mut self, &v: &Value) { + if let Value::DataInstOutput { inst, output_idx } = v { + if let Some(var) = self.propagator.vars.get_mut(&inst) { + assert_eq!(output_idx, 0); + *var = Err(UnknowableLocalVar); + } + } + } + + // FIXME(eddyb) we can't use `visit_data_inst_def` because we need either + // the resulting `DataInst`, or access to `FuncAt::type_of`. + fn visit_control_node_def(&mut self, func_at_control_node: FuncAt<'_, ControlNode>) { + if let ControlNodeKind::Block { insts } = func_at_control_node.def().kind { + for func_at_inst in func_at_control_node.at(insts) { + let data_inst_def = func_at_inst.def(); + let data_inst_form_def = &self.propagator.cx[data_inst_def.form]; + if let DataInstKind::QPtr(op) = &data_inst_form_def.kind { + let first_input_qptr_with_offset_and_access_type = match *op { + // HACK(eddyb) declaring local variables in loops is unsupported. + QPtrOp::FuncLocalVar(_) => { + self.propagator + .vars + .insert(func_at_inst.position, Err(UnknowableLocalVar)); + + None + } + + // NOTE(eddyb) these need to match up exactly with + // `propagate_through_data_inst`, for correctness. 
+ QPtrOp::Load { offset } => { + Some((offset, data_inst_form_def.output_types[0])) + } + QPtrOp::Store { offset } => Some(( + offset, + func_at_inst.at(data_inst_def.inputs[1]).type_of(self.propagator.cx), + )), + + _ => None, + }; + if let Some((offset, access_type)) = + first_input_qptr_with_offset_and_access_type + { + if let Some(access) = self.propagator.lookup_var_for_access( + data_inst_def.inputs[0], + offset, + access_type, + ) { + // FIXME(eddyb) a more accurate check would also + // consider whether values from previous iterations + // (or before the loop) are needed, not just mutations. + let _needs_previous_value = matches!(op, QPtrOp::Load { .. }) + || access.vector_elem_idx.is_some(); + + if let QPtrOp::Store { .. } = op { + self.mutated_var_indices.insert(access.var_idx); + } + + // Only visit the *other* inputs, not the `qptr` one. + for v in &data_inst_def.inputs[1..] { + self.visit_value_use(v); + } + + continue; + } + } + } + data_inst_def.inner_visit_with(self); + } + } else { + func_at_control_node.inner_visit_with(self); + } + } +} From 4fce1b7511aee72485eacd500601c5e96ad67c62 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Wed, 11 Oct 2023 05:15:00 +0300 Subject: [PATCH 22/22] [WIP] add `flow` analysis framework experiment. --- src/flow.rs | 956 +++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 3 +- src/passes/qptr.rs | 5 + 3 files changed, 963 insertions(+), 1 deletion(-) create mode 100644 src/flow.rs diff --git a/src/flow.rs b/src/flow.rs new file mode 100644 index 00000000..bc11c491 --- /dev/null +++ b/src/flow.rs @@ -0,0 +1,956 @@ +//! Flow-sensitive (side-effect) analysis framework. 
+ +use crate::func_at::{FuncAt, FuncAtMut}; +use crate::qptr::QPtrOp; +use crate::{ + AttrSet, ConstDef, ConstKind, Context, ControlNodeOutputDecl, ControlRegion, ControlRegionDef, + ControlRegionInputDecl, DataInst, DataInstDef, DataInstKind, Diag, FuncDefBody, FxIndexMap, + FxIndexSet, Type, +}; +use crate::{ControlNode, Value}; +use crate::{ControlNodeKind, DataInstFormDef}; +use smallvec::SmallVec; +use std::cell::Cell; +use std::collections::VecDeque; +use std::rc::Rc; +use std::{mem, slice}; + +// FIXME(eddyb) make the whole newtyped indexing situation better +// (using `EntityDefs` may make sense, but those are unique `Context`-wide). + +// FIXME(eddyb) switch to `u32` once the logic is shown to work. +type Idx = usize; + +// FIXME(eddyb) consider shortening to "Cap" and even maybe "Obj". + +/// Handle for a "capability" (see [`CapabilityDef`]). +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +struct Capability(Idx); + +/// "Capabilities" are [`Value`]s used to symbolically refer to [`Object`]s, +/// e.g. pointers referring to memory (sub)objects (aka "provenance"). +enum CapabilityDef { + // HACK(eddyb) singleton "escaped object set" quasi-capability. + Ambient { + // FIXME(eddyb) this should probably be a bitset. + reachable_objects: FxIndexSet, + }, + + WholeObject(Object), + // + // FIXME(eddyb) implement more cases, including "object slicing" but also + // merging capabilities across selects/loops into conservative sets. +} + +/// Handle for an "object" (see [`ObjectDef`]). +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +struct Object(Idx); + +/// "Objects" are statically known (and analyzable) disjoint "partitions" of the +/// state (in the (R)VSDG sense) during execution, e.g. the memory of a variable. +struct ObjectDef { + known_state: Option, + + /// Whether this object (its allocation and any instructions operating on it), + /// needs to be preserved for any reason whatsoever. 
+ /// + /// This flag starts out `false` and will (permanently) become `true` after + /// *any* operation other than writes and reads that can be rewritten away. + // + // FIXME(eddyb) this name is somewhat arbitrary and suboptimal. + // FIXME(eddyb) much more detailed state-oriented tracking should be possible. + keep_alive: bool, + // + // FIXME(eddyb) also track other objects (or even capabilities) reachable + // through this object, due to writes into it, instead of making them ambient. +} + +/// Handle for an "object state" (see [`ObjectStateDef`]). +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +struct ObjectState(Idx); + +// FIXME(eddyb) should this use `Rc`? should this try to be flatter? +// should it have "version"/"revision"/"snapshot" etc. in the name? +#[derive(PartialEq, Eq, Hash)] +enum ObjectStateDef { + AfterInst { + prev: Option, + inst: DataInst, + // HACK(eddyb) parent block of `inst`, only needed to be able to delete it. + parent_block: ControlNode, + }, + AfterSelect { + select_node: ControlNode, + state_before_select: Option, + per_case_states: SmallVec<[ObjectState; 2]>, + }, + // HACK(eddyb) the `ObjectState`s for before-loop/after-body are kept + // in a side table (`loop_object_state_initial_and_after_body`), and + // the information here only serves as a key for interning, so that + // the whole loop can be repeatedly processed to arrive at a fixpoint. + BeforeLoopBody { + loop_node: ControlNode, + object: Object, + }, +} + +// FIXME(eddyb) move stuff over to this. +struct InstInOutEffects { + // FIXME(eddyb) will these need to be dynamic? + inputs: &'static [Option], + outputs: &'static [Option], +} + +// FIXME(eddyb) should this be like a bitmask? +enum InputSink { + // FIXME(eddyb) implement object slicing. + Capability { + object_read: bool, + object_write: bool, + }, + + /// Value being written to an [`Object`] (referenced via a [`Capability`] input). 
+ ObjectWriteValue, +} + +enum OutputSource { + /// Allocate a new [`Object`] to output as a [`Capability`]. + NewObjectCapability, + + /// Value being read from an [`Object`] (referenced via a [`Capability`] input). + ObjectReadValue, +} + +/// Apply some flow analysis to `func_def_body`. +pub fn flow_func(cx: Rc, func_def_body: &mut FuncDefBody) { + // Avoid having to support unstructured control-flow. + // + // FIXME(eddyb) it should be possible to actually implement this, using + // classic dataflow techniques, and φ/BB args for value propagation. + if func_def_body.unstructured_cfg.is_some() { + return; + } + + let (objects, object_states, loop_object_state_initial_and_after_body, replaceable_insts) = { + let mut flow_cx = FlowContext { + cx: &cx, + ambient_capability: Capability(0), + value_to_capability: FxIndexMap::default(), + capabilities: vec![CapabilityDef::Ambient { reachable_objects: FxIndexSet::default() }], + objects: vec![], + object_state_interner: FxIndexSet::default(), + loop_object_state_initial_and_after_body: FxIndexMap::default(), + change_log: vec![], + replaceable_insts: FxIndexMap::default(), + }; + flow_cx.flow_through_control_region(func_def_body.at_mut_body()); + ( + flow_cx.objects, + flow_cx.object_state_interner, + flow_cx.loop_object_state_initial_and_after_body, + flow_cx.replaceable_insts, + ) + }; + + let mut remove_inst = + |inst, parent_block| match &mut func_def_body.control_nodes[parent_block].kind { + ControlNodeKind::Block { insts } => { + insts.remove(inst, &mut func_def_body.data_insts); + } + _ => unreachable!(), + }; + for (original_inst, (_, parent_block)) in replaceable_insts { + remove_inst(original_inst, parent_block); + } + + // HACK(eddyb) queue + set to safely visit all the `ObjectState`s used + // by non-`keep_alive` `Object`s. 
+ let mut state_visiting_queue = VecDeque::new(); + let mut visited_states = FxIndexSet::default(); + state_visiting_queue.extend( + objects.into_iter().filter(|obj| !obj.keep_alive).filter_map(|obj| obj.known_state), + ); + while let Some(state) = state_visiting_queue.pop_front() { + if !visited_states.insert(state) { + continue; + } + match &object_states[state.0] { + &ObjectStateDef::AfterInst { prev, inst, parent_block } => { + remove_inst(inst, parent_block); + state_visiting_queue.extend(prev); + } + ObjectStateDef::AfterSelect { state_before_select, per_case_states, .. } => { + state_visiting_queue.extend( + (*state_before_select).into_iter().chain(per_case_states.iter().copied()), + ); + } + ObjectStateDef::BeforeLoopBody { .. } => { + if let Some(&(initial_state, state_after_body)) = + loop_object_state_initial_and_after_body.get(&state) + { + state_visiting_queue.extend([initial_state, state_after_body]); + } + } + } + } +} + +struct FlowContext<'a> { + cx: &'a Context, + + // FIXME(eddyb) this will always be `Capability(0)` anyway, use constant? + ambient_capability: Capability, + + // FIXME(eddyb) use a much denser (or at least more efficient) representation, + // and/or reuse this for non-capability values as well. + value_to_capability: FxIndexMap, + + capabilities: Vec, + objects: Vec, + + object_state_interner: FxIndexSet, + + // HACK(eddyb) see `ObjectStateDef::BeforeLoopBody` docs for more details. + loop_object_state_initial_and_after_body: FxIndexMap, + + // HACK(eddyb) this allows a flat representation, and handling `Select` + // control nodes at a cost proportional only to the number of changes + // in any of the child regions (not the total number of e.g. objects). + change_log: Vec, + + /// Object read instructions with known output `Value`s, and also tracking + /// their parent `Block` control node for later removal. + // + // FIXME(eddyb) it should be possible to remove these as they are seen. 
+ replaceable_insts: FxIndexMap; 2]>, ControlNode)>, +} + +enum Change { + // FIXME(eddyb) does this only exist to detect allocations in loops? + AllocateNewObject { + new_object: Object, + alloc_inst: DataInst, + }, + + /// This [`Object`] is new in [`CapabilityDef::Ambient`]'s reachable set. + EscapeObjectToAmbient(Object), + + ObjectState { + object: Object, + + /// Previous value of the `known_state` field of `ObjectDef`. + prev_known_state: Option, + }, +} + +// HACK(eddyb) helper for `query_past_value`. +enum QueryStep { + Done(V), + + /// This write does not overlap with the queried part of the object at all. + DisjointWrite, + // FIXME(eddyb) add some way to handle partially overlapping reads with writes, + // where another callback has to combine multiple partial values together. +} + +/// Error type for `query_past_value` (or the callback it takes) failing to find +/// known states that are usable (i.e. containing compatible writes). +#[derive(Copy, Clone)] +struct UnknowableValue; + +// HACK(eddyb) this avoids using a map in some places. +struct ChainedMiniCache<'a, K, V> { + key: K, + value: Cell>, + // FIXME(eddyb) this is only used when `value` is `None`, maybe use a + // two-state `enum` instead, so this field is only used for "lazy init"? + prev: Option<&'a Self>, +} + +impl FlowContext<'_> { + fn intern_object_state(&mut self, object_state_def: ObjectStateDef) -> ObjectState { + ObjectState(match self.object_state_interner.get_full(&object_state_def) { + Some((idx, _)) => idx, + None => self.object_state_interner.insert_full(object_state_def).0, + }) + } + + /// Extract and/or synthesize a [`Value`], of type `value_type`, if possible, + /// from past [`Object`] writes, relative to the `current` [`ObjectState`], + /// using a caller-provided `DataInst`-level filter/map `step` callback. 
+ fn query_past_value( + &self, + func: FuncAtMut<'_, ()>, + current: ObjectState, + value_type: Type, + step: impl Fn(FuncAt<'_, DataInst>) -> Result, UnknowableValue>, + ) -> Result { + self.query_past_value_inner(func, current, value_type, None, &step) + } + + // HACK(eddyb) implementation detail of `query_past_value`. + // FIXME(eddyb) figure out some kind of localized "undo log" or similar, + // for any additions this makes to the function, so in case of errors + // it's all removed without leaving any unused inputs/outputs around. + fn query_past_value_inner( + &self, + mut func: FuncAtMut<'_, ()>, + current: ObjectState, + value_type: Type, + mini_cache: Option<&ChainedMiniCache<'_, ObjectState, Result>>, + step: &impl Fn(FuncAt<'_, DataInst>) -> Result, UnknowableValue>, + ) -> Result { + if let Some(mini_cache) = mini_cache { + if mini_cache.key == current { + return mini_cache.value.get().unwrap_or_else(|| { + // HACK(eddyb) this is easier than `try {...}`-ing the rest of the code. + let r = self.query_past_value_inner( + func, + current, + value_type, + mini_cache.prev, + step, + ); + mini_cache.value.set(Some(r)); + r + }); + } + } + + // FIXME(eddyb) avoid recursion, ideally find a good way to do caching! + match &self.object_state_interner[current.0] { + &ObjectStateDef::AfterInst { prev, inst, .. } => { + match step(func.reborrow().freeze().at(inst))? { + QueryStep::Done(v) => Ok(v), + QueryStep::DisjointWrite => self.query_past_value_inner( + func, + prev.ok_or(UnknowableValue)?, + value_type, + mini_cache, + step, + ), + } + } + ObjectStateDef::AfterSelect { select_node, state_before_select, per_case_states } => { + let select_node = *select_node; + + // FIXME(eddyb) represent the list of child regions without having them + // in a `Vec` (or `SmallVec`), which requires workarounds like this. + let get_case = |func: FuncAtMut<'_, ()>, i| match &func.at(select_node).def().kind { + ControlNodeKind::Select { cases, .. 
} => cases[i], + _ => unreachable!(), + }; + + let mini_cache_before_select = + state_before_select.map(|state_before_select| ChainedMiniCache { + key: state_before_select, + value: Cell::new(None), + // HACK(eddyb) this ensures that once `state_before_select` + // is encountered, the caching switches to some outer + // *even earlier* state (e.g. before an outer select). + prev: mini_cache, + }); + let mini_cache = mini_cache_before_select.as_ref().or(mini_cache); + + let per_case_values: SmallVec<[_; 2]> = per_case_states + .iter() + .map(|&state| { + self.query_past_value_inner( + func.reborrow(), + state, + value_type, + mini_cache, + step, + ) + }) + .collect::>()?; + + // HACK(eddyb) do not create outputs if all cases agree. + let v0 = per_case_values[0]; + if per_case_values[1..].iter().all(|&v| v == v0) { + return Ok(v0); + } + + let select_output_decls = &mut func.reborrow().at(select_node).def().outputs; + let output_idx = u32::try_from(select_output_decls.len()).unwrap(); + select_output_decls + .push(ControlNodeOutputDecl { attrs: Default::default(), ty: value_type }); + + for (case_idx, per_case_value) in per_case_values.into_iter().enumerate() { + let case = get_case(func.reborrow(), case_idx); + let per_case_outputs = &mut func.reborrow().at(case).def().outputs; + assert_eq!(per_case_outputs.len(), output_idx as usize); + per_case_outputs.push(per_case_value); + } + + Ok(Value::ControlNodeOutput { control_node: select_node, output_idx }) + } + &ObjectStateDef::BeforeLoopBody { loop_node, .. } => { + let &(initial_state, state_after_body) = self + .loop_object_state_initial_and_after_body + .get(&current) + .ok_or(UnknowableValue)?; + + let initial_value = self.query_past_value_inner( + func.reborrow(), + initial_state, + value_type, + mini_cache, + step, + )?; + + let (initial_inputs, body) = match &mut func.reborrow().at(loop_node).def().kind { + ControlNodeKind::Loop { initial_inputs, body, .. 
} => (initial_inputs, *body), + _ => unreachable!(), + }; + let input_idx = u32::try_from(initial_inputs.len()).unwrap(); + initial_inputs.push(initial_value); + + let body_input_decls = &mut func.reborrow().at(body).def().inputs; + assert_eq!(body_input_decls.len(), input_idx as usize); + body_input_decls + .push(ControlRegionInputDecl { attrs: Default::default(), ty: value_type }); + + let new_body_input = Value::ControlRegionInput { region: body, input_idx }; + + let value_after_body = self.query_past_value_inner( + func, + state_after_body, + value_type, + // HACK(eddyb) this avoids infinite recursion by caching + // the same value being later returned, with the same key. + Some(&ChainedMiniCache { + key: current, + value: Cell::new(Some(Ok(new_body_input))), + prev: None, + }), + step, + )?; + + // HACK(eddyb) ignore the loop entirely if its body isn't relevant. + // FIXME(eddyb) this should result in `new_body_input` being removed. + if value_after_body == new_body_input { + return Ok(initial_value); + } + + Ok(new_body_input) + } + } + } + + /// Apply active rewrites (i.e. `replaceable_insts`) to all `values`. + fn flow_into_values(&mut self, values: &mut [Value]) { + for v in values { + // FIXME(eddyb) should this run in a loop? + if let Value::DataInstOutput { inst, output_idx } = *v { + if let Some((replacement_values, _)) = self.replaceable_insts.get(&inst) { + if let Some(rv) = replacement_values[output_idx as usize] { + *v = rv; + } + } + } + } + } + + // HACK(eddyb) this (somewhat inefficiently?) erases all known states from + // ambiently reachable ("escaped") objects. 
+ fn unanalyzable_effects(&mut self) { + let ambiently_reachable_objects = match &self.capabilities[self.ambient_capability.0] { + CapabilityDef::Ambient { reachable_objects } => reachable_objects, + _ => unreachable!(), + }; + for &obj in ambiently_reachable_objects { + let prev = self.objects[obj.0].known_state.take(); + if prev.is_some() { + self.change_log.push(Change::ObjectState { object: obj, prev_known_state: prev }); + } + } + } + + fn escape_object_to_ambient(&mut self, obj: Object) { + let ambiently_reachable_objects = match &mut self.capabilities[self.ambient_capability.0] { + CapabilityDef::Ambient { reachable_objects } => reachable_objects, + _ => unreachable!(), + }; + if ambiently_reachable_objects.insert(obj) { + self.change_log.push(Change::EscapeObjectToAmbient(obj)); + } + + // HACK(eddyb) the persistent/"sticky" behavior goes against + // state snapshotting / the change log, but this is safer + // for now, and the value is only read at the very end anyway. + self.objects[obj.0].keep_alive = true; + } + + fn unanalyzable_value_uses(&mut self, values: &[Value]) { + for &v in values { + let cap = match v { + Value::Const(_) => continue, + _ => match self.value_to_capability.get(&v) { + None => continue, + Some(&cap) => cap, + }, + }; + // FIXME(eddyb) this could be more fine-grained wrt the object graph. + match self.capabilities[cap.0] { + CapabilityDef::Ambient { .. } => unreachable!(), + CapabilityDef::WholeObject(obj) => { + self.escape_object_to_ambient(obj); + } + } + } + } + + fn flow_through_data_inst( + &mut self, + mut func_at_inst: FuncAtMut<'_, DataInst>, + parent_block: ControlNode, + ) { + let cx = self.cx; + + let data_inst = func_at_inst.position; + + let DataInstDef { attrs: _, form, inputs } = func_at_inst.reborrow().def(); + + let DataInstFormDef { kind, output_types } = &cx[*form]; + + self.flow_into_values(inputs); + + match *kind { + // FIXME(eddyb) turn this into uses of the declarative metadata. 
+ DataInstKind::QPtr(QPtrOp::FuncLocalVar(_)) => { + assert!(inputs.len() <= 1); + + let new_obj_def = ObjectDef { + keep_alive: false, + known_state: Some(self.intern_object_state(ObjectStateDef::AfterInst { + prev: None, + inst: data_inst, + parent_block, + })), + }; + // HACK(eddyb) allocate new object + let new_obj = { + let idx = self.objects.len(); + self.objects.push(new_obj_def); + Object(idx) + }; + + self.change_log + .push(Change::AllocateNewObject { new_object: new_obj, alloc_inst: data_inst }); + + let new_cap_def = CapabilityDef::WholeObject(new_obj); + // HACK(eddyb) allocate new capability + let new_cap = { + let idx = self.capabilities.len(); + self.capabilities.push(new_cap_def); + Capability(idx) + }; + + self.value_to_capability + .insert(Value::DataInstOutput { inst: data_inst, output_idx: 0 }, new_cap); + } + + // FIXME(eddyb) turn this into uses of the declarative metadata. + DataInstKind::QPtr(QPtrOp::Store { .. }) => { + assert_eq!(inputs.len(), 2); + let dst_ptr = inputs[0]; + let stored_value = inputs[1]; + + if let Some(cap) = self.value_to_capability.get(&dst_ptr) { + if let CapabilityDef::WholeObject(obj) = self.capabilities[cap.0] { + // FIXME(eddyb) if the object is escaped i.e. ambiently + // reachable, this could still be considered to soundly + // overwrite any other state, because it's unsynchronized, + // so even if concurrent accesses could be performed, + // data races are still UB so we can assume unaliasing. + + let new_state = self.intern_object_state(ObjectStateDef::AfterInst { + prev: self.objects[obj.0].known_state, + inst: data_inst, + parent_block, + }); + + // FIXME(eddyb) avoid redundant `change_log` pushes, + // it should be possible to keep track of snapshot + // "watermarks", and "`Change` slot for this snapshot", + // within `ObjectDef`, to reuse `Change` slots. 
+ self.change_log.push(Change::ObjectState { + object: obj, + prev_known_state: self.objects[obj.0].known_state.replace(new_state), + }); + + // Only visit the value input, not the destination pointer. + self.unanalyzable_value_uses(&[stored_value]); + + return; + } + } + } + + // FIXME(eddyb) reify this so the search is only triggered by a use + // of the result of the load, not the load itself! + // FIXME(eddyb) this has no caching, could consider at least merge + // caching (to avoid adding duplicate region outputs etc.), and/or + // deduplicating the `qptr.load` itself if the object hadn't changed. + DataInstKind::QPtr(QPtrOp::Load { offset }) => { + assert_eq!(inputs.len(), 1); + let src_ptr = inputs[0]; + + let known_state = self.value_to_capability.get(&src_ptr).and_then(|cap| match self + .capabilities[cap.0] + { + CapabilityDef::WholeObject(obj) => self.objects[obj.0].known_state, + _ => None, + }); + + // HACK(eddyb) avoid accumulating orphan side-effects while + // trying to fix-point loops, if some things aren't monotonic. + // FIXME(eddyb) this is not perfect, because `flow_into_values` + // may have already used this by the time it's invalidated, so + // loops should ideally not be mutating the IR until outermost + // loops are complete (or if everything can be made monotonic). 
+ self.replaceable_insts.remove(&data_inst); + + if let Some(known_state) = known_state { + let const_undef = |ty| { + Value::Const(cx.intern(ConstDef { + attrs: AttrSet::default(), + ty, + kind: ConstKind::Undef, + })) + }; + + let access_type = output_types[0]; + let loaded_value = self.query_past_value( + func_at_inst.reborrow().at(()), + known_state, + access_type, + |func_at_write_inst| { + let func = func_at_write_inst.at(()); + let write_inst_def = func_at_write_inst.def(); + match (&cx[write_inst_def.form].kind, &write_inst_def.inputs[..]) { + (DataInstKind::QPtr(QPtrOp::FuncLocalVar(_)), []) => { + Ok(QueryStep::Done(const_undef(access_type))) + } + (DataInstKind::QPtr(QPtrOp::FuncLocalVar(_)), &[init_value]) + if offset == 0 + && func.at(init_value).type_of(cx) == access_type => + { + Ok(QueryStep::Done(init_value)) + } + + // FIXME(eddyb) move this so it can do layout, + // and filter by disjoint access ranges. + ( + &DataInstKind::QPtr(QPtrOp::Store { offset: store_offset }), + &[_, stored_value], + ) if offset == store_offset + && func_at_write_inst.at(stored_value).type_of(cx) + == access_type => + { + Ok(QueryStep::Done(stored_value)) + } + + _ => Err(UnknowableValue), + } + }, + ); + if let Ok(loaded_value) = loaded_value { + self.replaceable_insts.insert( + data_inst, + ([Some(loaded_value)].into_iter().collect(), parent_block), + ); + return; + } + } + } + + _ => {} + } + + self.unanalyzable_value_uses(&func_at_inst.def().inputs); + + // FIXME(eddyb) have a way to describe instructions as side-effect-free. 
+ self.unanalyzable_effects(); + } + + fn flow_through_control_region(&mut self, mut func_at_region: FuncAtMut<'_, ControlRegion>) { + let mut children = func_at_region.reborrow().at_children().into_iter(); + while let Some(func_at_control_node) = children.next() { + self.flow_through_control_node(func_at_control_node); + } + + let ControlRegionDef { inputs: _, children: _, outputs } = func_at_region.def(); + self.flow_into_values(outputs); + self.unanalyzable_value_uses(outputs); + } + + fn flow_through_control_node(&mut self, func_at_control_node: FuncAtMut<'_, ControlNode>) { + let control_node = func_at_control_node.position; + + // FIXME(eddyb) is this a good convention? + let mut func = func_at_control_node.at(()); + + match &mut func.reborrow().at(control_node).def().kind { + &mut ControlNodeKind::Block { insts } => { + let mut func_at_inst_iter = func.at(insts).into_iter(); + while let Some(func_at_inst) = func_at_inst_iter.next() { + self.flow_through_data_inst(func_at_inst, control_node); + } + } + ControlNodeKind::Select { kind: _, scrutinee, cases } => { + self.flow_into_values(slice::from_mut(scrutinee)); + self.unanalyzable_value_uses(&[*scrutinee]); + + let num_cases = cases.len(); + + // FIXME(eddyb) represent the list of child regions without having them + // in a `Vec` (or `SmallVec`), which requires workarounds like this. + let get_case = |func: FuncAtMut<'_, ()>, i| match &func.at(control_node).def().kind + { + ControlNodeKind::Select { cases, .. } => cases[i], + _ => unreachable!(), + }; + + // HACK(eddyb) degenerate `Select`s do not actually need merges. + if num_cases <= 1 { + if num_cases == 1 { + let case = get_case(func.reborrow(), 0); + self.flow_through_control_region(func.at(case)); + } + return; + } + + // HACK(eddyb) this is how we can both roll back changes to + // `ObjectDef`s' `known_state`, and know which objects changed + // in the first place (to merge their changed states, together). 
+ let change_log_start = self.change_log.len(); + + let mut all_escaped_objects = FxIndexSet::default(); + let mut obj_to_per_case_states = FxIndexMap::>::default(); + for case_idx in 0..num_cases { + let case = get_case(func.reborrow(), case_idx); + self.flow_through_control_region(func.reborrow().at(case)); + + // NOTE(eddyb) we traverse the change log forwards, as we + // already have a way to determine whether we've seen any + // changes for each object, and only the oldest change + // is needed to roll back the object to its original state. + for change in self.change_log.drain(change_log_start..) { + match change { + Change::AllocateNewObject { .. } => { + // FIXME(eddyb) should this be banned in some way? + } + Change::EscapeObjectToAmbient(obj) => { + let ambiently_reachable_objects = + match &mut self.capabilities[self.ambient_capability.0] { + CapabilityDef::Ambient { reachable_objects } => { + reachable_objects + } + _ => unreachable!(), + }; + assert!(ambiently_reachable_objects.remove(&obj)); + all_escaped_objects.insert(obj); + } + Change::ObjectState { object: obj, prev_known_state } => { + let original_state = prev_known_state; + + let per_case_states = + obj_to_per_case_states.entry(obj).or_insert_with(|| { + let mut per_case_states = + SmallVec::with_capacity(num_cases); + + // This case may be the first to change this + // object - thankfully we know the original + // state (which will be common across all cases). + per_case_states + .extend((0..case_idx).map(|_| original_state)); + + per_case_states + }); + + if per_case_states.len() <= case_idx { + let new_state = mem::replace( + &mut self.objects[obj.0].known_state, + original_state, + ); + per_case_states.push(new_state); + } + assert_eq!(per_case_states.len() - 1, case_idx); + } + } + } + + // Some objects may only have been changed in previous cases. 
+ for (&object, per_case_states) in &mut obj_to_per_case_states { + if per_case_states.len() <= case_idx { + per_case_states.push(self.objects[object.0].known_state); + assert_eq!(per_case_states.len() - 1, case_idx); + } + } + } + + for escaped_obj in all_escaped_objects { + self.escape_object_to_ambient(escaped_obj); + } + + // Objects changed in at least one case can now be merged, + // by creating `ObjectStateDef::AfterSelect`s for all of them. + for (obj, per_case_states) in obj_to_per_case_states { + assert_eq!(per_case_states.len(), num_cases); + + // HACK(eddyb) do not create a new state if all cases agree. + let s0 = per_case_states[0]; + let merged_state = if per_case_states[1..].iter().all(|&s| s == s0) { + s0 + } else { + // FIXME(eddyb) bail out of storing `None::`s sooner. + per_case_states.into_iter().collect::>().map(|per_case_states| { + self.intern_object_state(ObjectStateDef::AfterSelect { + select_node: control_node, + state_before_select: self.objects[obj.0].known_state, + per_case_states, + }) + }) + }; + let prev_known_state = + mem::replace(&mut self.objects[obj.0].known_state, merged_state); + self.change_log.push(Change::ObjectState { object: obj, prev_known_state }); + } + } + ControlNodeKind::Loop { initial_inputs, body, repeat_condition: _ } => { + self.flow_into_values(initial_inputs); + self.unanalyzable_value_uses(initial_inputs); + + // HACK(eddyb) this may get expensive, as most objects may not + // end up being changed at all, but further optimizing this while + // remaining sound is an exercise for another day. 
+ struct LoopObjectState { + initial: Option, + before_body: ObjectState, + after_body: Option, + } + let mut loop_object_states: FxIndexMap = + (0..self.objects.len()) + .map(Object) + .map(|obj| { + ( + obj, + LoopObjectState { + initial: self.objects[obj.0].known_state, + before_body: self.intern_object_state( + ObjectStateDef::BeforeLoopBody { + loop_node: control_node, + object: obj, + }, + ), + after_body: None, + }, + ) + }) + .collect(); + + let body = *body; + + let body_change_log_start = self.change_log.len(); + + // FIXME(eddyb) this risks monotonic runaway (and e.g. OOM), + // but ideally all that happens is taking the fixpoint of the + // loop body by repeatedly (re)processing it, with interning of + // `ObjectState`s providing some of the "saturating" behavior. + let mut states_changed; + loop { + states_changed = false; + + let replaceable_insts_len = self.replaceable_insts.len(); + + // Reset to the start of the loop body (like region inputs). + for (&obj, state) in &loop_object_states { + self.objects[obj.0].known_state = Some(state.before_body); + } + + self.flow_through_control_region(func.reborrow().at(body)); + + // NOTE(eddyb) the repeat condition is technically part of + // the loop body, as if it were an extra region output. + let repeat_condition = match &mut func.reborrow().at(control_node).def().kind { + ControlNodeKind::Loop { repeat_condition, .. } => repeat_condition, + _ => unreachable!(), + }; + self.flow_into_values(slice::from_mut(repeat_condition)); + self.unanalyzable_value_uses(&[*repeat_condition]); + + // HACK(eddyb) reduce work on future iterations by pruning + // entries which don't actually correspond to any changes. 
+ loop_object_states.retain(|obj, state| { + let new_after_body = self.objects[obj.0].known_state; + if new_after_body != state.after_body { + state.after_body = new_after_body; + states_changed = true; + + if let (Some(initial), Some(after_body)) = + (state.initial, state.after_body) + { + self.loop_object_state_initial_and_after_body + .insert(state.before_body, (initial, after_body)); + } else { + self.loop_object_state_initial_and_after_body + .remove(&state.before_body); + } + } + state.after_body != Some(state.before_body) + }); + + // FIXME(eddyb) this isn't even needed for `ObjectState` + // changes, because those are handled specially above/below. + let mut attempted_dynamic_object_allocation = false; + for change in self.change_log.drain(body_change_log_start..) { + match change { + Change::AllocateNewObject { alloc_inst, .. } => { + func.reborrow().at(alloc_inst).def().attrs.push_diag( + self.cx, + Diag::bug(["loops cannot allocate objects".into()]), + ); + attempted_dynamic_object_allocation = true; + break; + } + Change::EscapeObjectToAmbient(_) => { + // NOTE(eddyb) not undone because we want to + // *accumulate* escapes, not reduce them. + // FIXME(eddyb) this is bad because unlike the + // regular state tracking, this starts maximally + // permissive and becomes more restrictive with + // more iterations (as escapes are discovered). + // HACK(eddyb) for now, `replaceable_insts` slots + // are manually cleared in `flow_through_data_inst` + // but it should either use the change log or + // make escape tracking conservative. + states_changed = true; + } + Change::ObjectState { object: obj, prev_known_state: _ } => { + // HACK(eddyb) sanity check, just in case the + // `retain` above is overeager. 
+ assert!(loop_object_states.contains_key(&obj)); + } + } + } + + if attempted_dynamic_object_allocation + || !states_changed && replaceable_insts_len == self.replaceable_insts.len() + { + break; + } + } + + // NOTE(eddyb) this is needed to preserve the original states of + // changed objects - it could've been earlier, but that may bloat + // the change log (before pruning unchanged objects, at least). + for (obj, state) in loop_object_states { + self.change_log + .push(Change::ObjectState { object: obj, prev_known_state: state.initial }); + } + } + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 83403c74..c292610a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -154,6 +154,7 @@ // (i.e. using inner doc comments). pub mod cfg; mod context; +pub mod flow; pub mod func_at; pub mod print; pub mod transform; @@ -1073,7 +1074,7 @@ pub enum DataInstKind { /// avoid them because of their (negative) impact on analyses and transforms, /// with their main vestigial purpose being to encode multiple return values /// from functions, which can be done more directly in other IRs (and SPIR-T) -#[derive(Copy, Clone, PartialEq, Eq)] +#[derive(Copy, Clone, PartialEq, Eq, Hash)] pub enum Value { Const(Const), diff --git a/src/passes/qptr.rs b/src/passes/qptr.rs index 96284365..24524afa 100644 --- a/src/passes/qptr.rs +++ b/src/passes/qptr.rs @@ -69,6 +69,11 @@ pub fn partition_and_propagate(module: &mut Module, layout_config: &qptr::Layout func_def_body, ); + if true { + crate::flow::flow_func(cx.clone(), func_def_body); + break; + } + let report = qptr::simplify::propagate_contents_of_local_vars_in_func( cx.clone(), layout_config,