diff --git a/Project.toml b/Project.toml
index 7c385ec..3c32cac 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "FunctionProperties"
 uuid = "f62d2435-5019-4c03-9749-2d4c77af0cbc"
-version = "0.1.7"
+version = "0.1.8"
 authors = ["SciML"]
 
 [deps]
diff --git a/src/FunctionProperties.jl b/src/FunctionProperties.jl
index c20f4fc..06dcfa8 100644
--- a/src/FunctionProperties.jl
+++ b/src/FunctionProperties.jl
@@ -5,6 +5,85 @@ using Core: GotoIfNot
 # Backstop against pathological recursion depth; real call trees that matter here are shallow.
 const RECURSION_LIMIT = 256
 
+# ---- experimental: constant-propagation-aware recursion ------------------------------------
+#
+# When recursing into a call, ordinary analysis widens every argument to its type, discarding
+# constants. A branch decided by a constant argument (e.g. selecting a buffer by a literal index
+# inside a parameter container) therefore looks value-dependent even though every real call site
+# folds it. If we instead preserve the `Core.Const` argument and re-run *inference* (no optimizer,
+# so no library/structural branches are inlined into view), such branches fold to `Core.Const`
+# conditions that [`_is_const_gotoifnot`](@ref) already skips -- generalizing [`is_leaf_sig`](@ref)
+# without any per-container knowledge.
+#
+# This relies on `Base.Compiler` (`Core.Compiler`) internals whose API churns across Julia versions
+# (the `InferenceState` construction and inferred-source extraction already differ between 1.12 and
+# 1.13), so it is *functionally* capability-gated -- see `_const_prop_capable` -- and OFF by
+# default. Enable with [`enable_const_prop!`](@ref).
+const _CC = isdefined(Base, :Compiler) ? Base.Compiler : Core.Compiler
+
+const _CONST_PROP = Ref(false)
+# `nothing` until the functional probe has run; then `true`/`false`.
+const _CONST_PROP_CAPABLE = Ref{Union{Nothing, Bool}}(nothing)
+
+# Fixture with a branch decided purely by a constant integer index -- the shape the feature must
+# fold. Used only by the capability probe.
+struct _ProbeContainer
+    a::Int
+    b::Int
+end
+@generated function _probe_indexed(x::_ProbeContainer, idx::Int)
+    quote
+        if idx == 1
+            return x.a
+        else
+            return x.b
+        end
+    end
+end
+
+# Verify, on the running Julia, that constant inference actually folds a constant-decided branch:
+# the constant-index call must come back branch-free while the widened-index call must not. If the
+# compiler internals we depend on have shifted shape, this returns `false` and the feature stays
+# inert (so behaviour is identical to the plain type recursion). Probed once, then cached.
+function _probe_const_prop()
+    sig = Tuple{typeof(_probe_indexed), _ProbeContainer, Int}
+    folded = _const_infer_src(sig, Any[Core.Const(_probe_indexed), _ProbeContainer, Core.Const(1)])
+    widened = _const_infer_src(sig, Any[Core.Const(_probe_indexed), _ProbeContainer, Int])
+    folded isa Core.CodeInfo || return false
+    widened isa Core.CodeInfo || return false
+    return _count_nonconst_gotoifnot(folded) == 0 && _count_nonconst_gotoifnot(widened) > 0
+end
+
+function _const_prop_capable()
+    v = _CONST_PROP_CAPABLE[]
+    if v === nothing
+        v = try
+            _probe_const_prop()
+        catch
+            false
+        end
+        _CONST_PROP_CAPABLE[] = v
+    end
+    return v
+end
+
+"""
+    enable_const_prop!(on::Bool = true) -> Bool
+
+Experimental. Toggle constant-propagation-aware recursion in [`hasbranching`](@ref). When on (and
+the running Julia's compiler internals still fold a constant-decided branch, verified functionally
+by [`_const_prop_capable`](@ref)), a call with constant arguments is re-inferred with those
+constants preserved, so value-independent branches decided by a constant (e.g. selecting a buffer
+by a literal index) fold away instead of being reported. Off by default because it depends on
+compiler internals. Returns the effective state.
+"""
+function enable_const_prop!(on::Bool = true)
+    _CONST_PROP[] = on
+    return _const_prop_active()
+end
+
+_const_prop_active() = _CONST_PROP[] && _const_prop_capable()
+
 """
     is_leaf(f, args...) -> Bool
 
@@ -96,17 +175,100 @@ function _hasbranching(@nospecialize(sig), seen, depth)
         # Generated functions that were not expanded come back as `Method`, not `CodeInfo`;
         # there is no body to scan, so treat them as leaves.
         ci isa Core.CodeInfo || continue
-        for stmt in ci.code
-            if isa(stmt, GotoIfNot)
-                _is_const_gotoifnot(stmt, ci) || return true
-            elseif _recurse_call(stmt, ci, seen, depth)
-                return true
-            end
+        _scan_codeinfo(ci, seen, depth) && return true
+    end
+    return false
+end
+
+function _scan_codeinfo(ci, seen, depth)
+    for stmt in ci.code
+        if isa(stmt, GotoIfNot)
+            _is_const_gotoifnot(stmt, ci) || return true
+        elseif _recurse_call(stmt, ci, seen, depth)
+            return true
         end
     end
     return false
 end
 
+# Constant-argument recursion (experimental, see [`enable_const_prop!`](@ref)): re-infer the callee
+# with the constant lattice elements preserved so branches decided by a constant argument fold to
+# `Core.Const` conditions. Inference is run *without* the optimizer, so no library/structural
+# branches are inlined into view. Falls back to the plain type recursion whenever the constant
+# inference is unavailable or fails.
+function _hasbranching_const(@nospecialize(sig), argtypes, seen, depth)
+    depth > RECURSION_LIMIT && return false
+    key = (sig, _const_key(argtypes))
+    key in seen && return false
+    push!(seen, key)
+    src = _const_infer_src(sig, argtypes)
+    src isa Core.CodeInfo || return _hasbranching(sig, seen, depth)
+    return _scan_codeinfo(src, seen, depth)
+end
+
+_const_key(argtypes) = map(x -> x isa Core.Const ? (true, x.val) : (false, x), argtypes)
+
+# Run inference on `sig` with the given argument lattice (some `Core.Const`) preserved, and return
+# the inferred (unoptimized) `CodeInfo`, or `nothing` if the compiler internals do not cooperate.
+# The `InferenceState` construction and the inferred-source location differ across Julia versions:
+# 1.12 accepts `InferenceState(result, cache_mode, interp)` and exposes the body on `result.src`,
+# while 1.13 wants the uninferred source passed explicitly and exposes the body on `frame.src`. We
+# try the explicit-source form first (works on both) with the non-caching `:volatile` mode, then
+# fall back, and read whichever of `frame.src`/`result.src` is a `CodeInfo`. Any shape we don't
+# recognise simply yields `nothing`, and the functional probe (`_const_prop_capable`) keeps the
+# whole feature inert on such versions.
+function _const_infer_src(@nospecialize(sig), argtypes)
+    # Look the callee up with `Base._which`: its `Core.MethodMatch` carries the static parameters
+    # that `sig` binds. Specializing on an empty `Core.svec()` would hand inference a
+    # `MethodInstance` without them, so a parametric (`where {T}`) callee cannot be reliably
+    # re-inferred. A failed lookup (no/ambiguous method, or a Julia lacking these internals)
+    # returns `nothing`, which callers treat as "constant inference unavailable".
+    mi = try
+        Base.specialize_method(Base._which(sig))
+    catch
+        return nothing
+    end
+    overridden = BitVector(x isa Core.Const for x in argtypes)
+    src0 = try
+        _CC.retrieve_code_info(mi, Base.get_world_counter())
+    catch
+        nothing
+    end
+    # A fresh `InferenceResult`/`InferenceState` per attempt: an `InferenceResult` cannot be
+    # re-inferred once used.
+    for build in (
+        interp -> src0 isa Core.CodeInfo ?
+            _CC.InferenceState(_new_result(mi, argtypes, overridden), src0, :volatile, interp) :
+            nothing,
+        interp -> _CC.InferenceState(_new_result(mi, argtypes, overridden), :volatile, interp),
+    )
+        src = try
+            interp = _CC.NativeInterpreter()
+            frame = build(interp)
+            frame === nothing && continue
+            _CC.typeinf(interp, frame)
+            _inferred_src(frame)
+        catch
+            nothing
+        end
+        src isa Core.CodeInfo && return src
+    end
+    return nothing
+end
+
+_new_result(mi, argtypes, overridden) = _CC.InferenceResult(mi, Any[argtypes...], overridden)
+
+function _inferred_src(frame)
+    if isdefined(frame, :src) && getfield(frame, :src) isa Core.CodeInfo
+        return getfield(frame, :src)
+    end
+    r = getfield(frame, :result)
+    return (r isa _CC.InferenceResult && r.src isa Core.CodeInfo) ? r.src : nothing
+end
+
+_count_nonconst_gotoifnot(ci::Core.CodeInfo) =
+    count(s -> isa(s, GotoIfNot) && !_is_const_gotoifnot(s, ci), ci.code)
+
 # A `GotoIfNot` whose condition type inference has *proven* constant is a compile-time branch,
 # not a value-dependent one: e.g. an `x isa T` test on a concretely-typed field (the SciML
 # `ODEFunction` wrapper) or the device/type-introspection dispatch inside ML library layers
@@ -143,7 +305,9 @@ function _recurse_call(@nospecialize(stmt), ci, seen, depth)
             getfield(mi, :def).specTypes : nothing
         )
         callsig === nothing && return false
-        return _recurse_sig(callsig, nothing, seen, depth)
+        _, fval = _resolve_callee(call.args[2], ci)
+        arglat = Any[_arg_lattice(a, ci) for a in @view call.args[3:end]]
+        return _recurse_sig(callsig, fval, arglat, seen, depth)
     end
 
     Meta.isexpr(call, :call) || return false
@@ -152,8 +316,8 @@ function _recurse_call(@nospecialize(stmt), ci, seen, depth)
     end
     ftype, fval = _resolve_callee(call.args[1], ci)
     ftype === nothing && return false
-    argtypes = Any[_value_type(a, ci) for a in @view call.args[2:end]]
-    return _recurse_sig(Tuple{ftype, argtypes...}, fval, seen, depth)
+    arglat = Any[_arg_lattice(a, ci) for a in @view call.args[2:end]]
+    return _recurse_sig(Tuple{ftype, (_lat_type(x) for x in arglat)...}, fval, arglat, seen, depth)
 end
 
 _is_apply(@nospecialize(f)) =
@@ -185,10 +349,11 @@ function _recurse_apply(call, ci, seen, depth)
             return false
         end
     end
-    return _recurse_sig(Tuple{ftype, argtypes...}, fval, seen, depth)
+    # Splatted arguments are recovered from tuple element *types*; constants are not available here.
+    return _recurse_sig(Tuple{ftype, argtypes...}, fval, Any[argtypes...], seen, depth)
 end
 
-function _recurse_sig(@nospecialize(callsig), @nospecialize(fval), seen, depth)
+function _recurse_sig(@nospecialize(callsig), @nospecialize(fval), arglat, seen, depth)
     # Honor user `is_leaf` overrides when the concrete function value is recoverable.
     fval !== nothing && is_leaf(fval) && return false
     # Signature-level overrides: exemptions that depend on the argument types.
@@ -199,9 +364,45 @@ function _recurse_sig(@nospecialize(callsig), @nospecialize(fval), seen, depth)
         return false
     end
     _is_library_method(m) && return false
+    if _const_prop_active() && any(x -> x isa Core.Const, arglat)
+        funclat = fval !== nothing ? Core.Const(fval) : _first_param(callsig)
+        return _hasbranching_const(callsig, Any[funclat, arglat...], seen, depth + 1)
+    end
     return _hasbranching(callsig, seen, depth + 1)
 end
 
+_first_param(@nospecialize(sig)) =
+    (sig isa DataType && !isempty(sig.parameters)) ? sig.parameters[1] : Any
+_lat_type(@nospecialize(x)) = x isa Core.Const ? Core.Typeof(x.val) : x
+
+# Keep a `Core.Const` lattice element untouched; widen anything else to a plain type.
+_const_or_widen(@nospecialize(t)) = t isa Core.Const ? t : _widen(t)
+
+# Argument lattice element: a `Core.Const` when the argument is a compile-time constant, otherwise
+# the widened type. Preserving the `Core.Const` is what lets a constant index survive the recursion
+# boundary so `_hasbranching_const` can fold the branch it decides.
+function _arg_lattice(@nospecialize(a), ci)
+    if a isa Core.SSAValue
+        return _const_or_widen(ci.ssavaluetypes[a.id])
+    elseif a isa Core.Argument || a isa Core.SlotNumber
+        # Both kinds index `slottypes`; they differ only in the name of the index field.
+        st = ci.slottypes
+        st === nothing && return Any
+        return _const_or_widen(st[a isa Core.Argument ? a.n : a.id])
+    elseif a isa GlobalRef
+        return (isdefined(a.mod, a.name) && isconst(a.mod, a.name)) ?
+            Core.Const(getglobal(a.mod, a.name)) : Any
+    elseif a isa QuoteNode
+        return Core.Const(a.value)
+    elseif a isa Expr || a isa Core.GotoNode || a isa GotoIfNot ||
+            a isa Core.NewvarNode || a isa Core.ReturnNode
+        return Any
+    else
+        # Raw literal constant embedded in the IR (e.g. an `Int` index).
+        return Core.Const(a)
+    end
+end
+
 # Library code (`Base`, `Core`, stdlibs) is treated as a leaf: its branches are structural or
 # compile-time, not the value-dependent user logic `hasbranching` is meant to surface.
 function _is_library_method(m::Method)
@@ -254,6 +455,6 @@ _widen(@nospecialize t) =
     t isa Core.PartialStruct ? t.typ :
     isa(t, Type) ? t : Any
 
-export hasbranching, is_leaf, is_leaf_sig
+export hasbranching, is_leaf, is_leaf_sig, enable_const_prop!
 
 end
diff --git a/test/core_tests.jl b/test/core_tests.jl
index 376e831..ce284ac 100644
--- a/test/core_tests.jl
+++ b/test/core_tests.jl
@@ -159,3 +159,50 @@ rhs_with_plumbing(u, p, t) = select_buffer(p, 1) * u
 @test FunctionProperties.hasbranching(rhs_with_plumbing, 1.0, TwoBuffers(1.0, 2.0), 0.0)
 FunctionProperties.is_leaf_sig(::Type{<:Tuple{typeof(select_buffer), TwoBuffers, Vararg}}) = true
 @test !FunctionProperties.hasbranching(rhs_with_plumbing, 1.0, TwoBuffers(1.0, 2.0), 0.0)
+
+# ---------------------------------------------------------------------------------------------
+# Experimental: constant-propagation-aware recursion (`enable_const_prop!`).
+#
+# A branch decided by a *constant* argument (e.g. selecting a buffer by a literal index) is
+# value-independent, but ordinary recursion widens the argument and reports it. With const-prop on,
+# the callee is re-inferred with the constant preserved so such branches fold away. It stays
+# conservative: a genuinely value-dependent branch, or a dynamic (non-constant) index, is still
+# reported. Off by default, so it must not change any behavior unless explicitly enabled.
+struct TwoBufferParams
+    a::Vector{Float64}
+    b::Vector{Float64}
+end
+@generated function pick_buffer(p::TwoBufferParams, idx::Int)
+    quote
+        if idx == 1
+            return p.a
+        elseif idx == 2
+            return p.b
+        else
+            throw(BoundsError(p, idx))
+        end
+    end
+end
+cp_relu(x) = x > 0 ? x : zero(x)
+rhs_const_index(p) = @inbounds pick_buffer(p, 1)[1]
+rhs_dynamic_index(p, i) = @inbounds pick_buffer(p, i)[1]
+rhs_real_branch(u, p) = cp_relu(u) + @inbounds pick_buffer(p, 1)[1]
+tbp = TwoBufferParams([1.0], [2.0])
+
+# Default: off -> the value-independent index branch is (conservatively) still reported.
+@test !FunctionProperties.enable_const_prop!(false)
+@test FunctionProperties.hasbranching(rhs_const_index, tbp)
+
+if FunctionProperties._const_prop_capable()
+    @test FunctionProperties.enable_const_prop!(true)
+    try
+        @test !FunctionProperties.hasbranching(rhs_const_index, tbp) # constant index folds
+        @test FunctionProperties.hasbranching(rhs_real_branch, 1.0, tbp) # genuine branch kept
+        @test FunctionProperties.hasbranching(rhs_dynamic_index, tbp, 1) # dynamic index: conservative
+    finally
+        FunctionProperties.enable_const_prop!(false)
+    end
+else
+    # Probe failed on this Julia: record a skipped test so the missing coverage is visible.
+    @test_skip !FunctionProperties.hasbranching(rhs_const_index, tbp)
+end