diff --git a/Project.toml b/Project.toml index 7c385ec..32a42f2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FunctionProperties" uuid = "f62d2435-5019-4c03-9749-2d4c77af0cbc" -version = "0.1.7" +version = "1.0.0" authors = ["SciML"] [deps] diff --git a/docs/Project.toml b/docs/Project.toml index 40a63e3..edaad76 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,4 +4,4 @@ FunctionProperties = "f62d2435-5019-4c03-9749-2d4c77af0cbc" [compat] Documenter = "1" -FunctionProperties = "0.1.2" +FunctionProperties = "1" diff --git a/docs/make.jl b/docs/make.jl index 96f55d1..e74f9c6 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -28,6 +28,6 @@ makedocs( ) deploydocs( - repo = "github.com/SciML/MultiScaleArrays.jl.git"; + repo = "github.com/SciML/FunctionProperties.jl.git"; push_preview = true ) diff --git a/src/FunctionProperties.jl b/src/FunctionProperties.jl index c20f4fc..16a433f 100644 --- a/src/FunctionProperties.jl +++ b/src/FunctionProperties.jl @@ -5,6 +5,27 @@ using Core: GotoIfNot # Backstop against pathological recursion depth; real call trees that matter here are shallow. const RECURSION_LIMIT = 256 +# Refutations are only *started* this close to the root. A refutation that cannot bottom out +# within the remaining depth budget can never succeed, so starting one deep inside a +# constant-recursion tower only pays for a doomed re-descent -- and since failed refutations are +# (soundly) not memoized, a tower deeper than `RECURSION_LIMIT` otherwise re-descends once per +# level: O(limit^2) inference calls, which measured in the tens of minutes. Constant recursion +# that legitimately folds is shallow (a handful of levels); anything deeper is conservatively +# reported as branching. +const REFUTATION_DEPTH_LIMIT = 32 + +# `hasbranching` recurses through statically resolved calls. Ordinary analysis widens every argument +# to its type, which loses constants: a branch decided by a *constant* argument (e.g. 
selecting a +# buffer by a literal index inside a parameter container) then looks value-dependent even though +# every real call site folds it. When the running Julia's compiler cooperates, such a call is +# re-inferred with its `Core.Const` arguments preserved (no optimizer, so no library/structural +# branches are inlined into view) and the constant-decided branch folds to a `Core.Const` condition +# that `_is_const_gotoifnot` skips. This depends on `Base.Compiler`/`Core.Compiler` internals whose +# API changes across Julia versions, so it is *functionally* gated (see `_const_prop_capable`): it +# activates only where a probe confirms folding actually works, and otherwise the analysis falls +# back to the plain type recursion. +const _CC = isdefined(Base, :Compiler) ? Base.Compiler : Core.Compiler + """ is_leaf(f, args...) -> Bool @@ -20,27 +41,6 @@ FunctionProperties.is_leaf(::typeof(my_fn)) = true """ is_leaf(f, args...) = false -""" - is_leaf_sig(sig::Type{<:Tuple}) -> Bool - -Signature-level counterpart to [`is_leaf`](@ref), consulted while recursing through statically -resolved calls. `sig` is the call's `Tuple{typeof(f), argtypes...}`. Return `true` to treat the -call as branch-free and stop recursing into it. - -Use this (instead of `is_leaf`) when the exemption depends on the *argument types*, not just the -function. The motivating case is value-independent plumbing whose branch is on an index/type -rather than on traced values — e.g. selecting a buffer by integer index inside a parameter -container, where each real call site passes a literal index that constant-folds the branch away, -but the recursion only sees the widened argument type. - -## Example - -```julia -FunctionProperties.is_leaf_sig(::Type{<:Tuple{typeof(getindex), <:MyParamContainer, Vararg}}) = true -``` -""" -is_leaf_sig(@nospecialize(sig)) = false - """ hasbranching(f, x...) @@ -64,6 +64,10 @@ that branches living behind a non-inlined call boundary are still detected. 
Call are structural/compile-time rather than value-dependent user logic, and recursing into them (e.g. matrix multiply, broadcasting, `getindex` bounds checks) produces false positives. +Branches whose condition inference proves constant are ignored (they are not value-dependent), +and — where the compiler cooperates — a call with constant arguments is re-inferred with those +constants preserved so branches they decide fold away rather than being reported. + ## Customizing and Removing Functions from the Checks Some functions may produce false positives because their internal branches are compile-time @@ -77,44 +81,142 @@ FunctionProperties.is_leaf(::typeof(my_fn)) = true function hasbranching(f, x...) is_leaf(f, x...) && return false sig = Tuple{Core.Typeof(f), Core.Typeof.(x)...} - return _hasbranching(sig, Set{Any}(), 0) + return _hasbranching(sig, Set{Any}(), 0) != NOBRANCH end +# Scan results form a tri-state. `LIMITED` ("could be branching") is distinct from `BRANCH` so +# refutation is only ever attempted on a branch that was actually *seen*: a limit-tainted result +# would make any refutation fail (its scan exhausts the same budget), so attempting one only pays +# for a doomed re-descent -- on a deep distinct-signature tower (e.g. `Val{N}` recursion), once per +# level, which measured in the tens of minutes. +const NOBRANCH = 0x00 +const BRANCH = 0x01 +const LIMITED = 0x02 + +# `seen` serves two roles: cycle breaking for sigs on the current DFS path, and memoization of +# sigs proven branch-free. A `NOBRANCH` result is sound to memoize globally -- the scan uses +# widened argument types, and constants can only fold branches away, so a type-level `NOBRANCH` +# holds for every call site. 
Non-`NOBRANCH` results are NOT memoized: those sigs are popped before +# returning, because constant refutation may flip a `BRANCH` to branch-free at one call site while +# another call site of the same sig (with different or no constants) must still re-analyze and +# report the branch (memoizing them produced order-dependent false negatives), and `LIMITED` +# depends on the remaining depth budget. function _hasbranching(@nospecialize(sig), seen, depth) - depth > RECURSION_LIMIT && return false - sig in seen && return false + depth > RECURSION_LIMIT && return LIMITED + r = _unwrap_wrapper(sig, seen, depth) + r === nothing || return r + sig in seen && return NOBRANCH push!(seen, sig) + # If the *entry* IR cannot be obtained -- e.g. reflection is restricted because we are running + # inside a generated-function expansion -- the safe answer is "could be branching", not + # "assume a leaf": a silent branch-free here returned false negatives to generators that + # consulted `hasbranching` while expanding. For a *nested* callee whose IR is unobtainable + # even though `which` resolved it, the leaf treatment is kept: that is the same tier as the + # other unresolvable-callee give-ups, and on older Julia versions some library-adjacent + # signatures legitimately fail reflection mid-recursion. results = try Base.code_typed_by_type(sig; optimize = false) catch - return false + depth == 0 || return NOBRANCH + delete!(seen, sig) + return LIMITED end + scanned_any = false for pair in results ci = first(pair) # Generated functions that were not expanded come back as `Method`, not `CodeInfo`; # there is no body to scan, so treat them as leaves. 
ci isa Core.CodeInfo || continue - for stmt in ci.code - if isa(stmt, GotoIfNot) - _is_const_gotoifnot(stmt, ci) || return true - elseif _recurse_call(stmt, ci, seen, depth) - return true - end + scanned_any = true + r = _scan_codeinfo(ci, seen, depth) + if r != NOBRANCH + delete!(seen, sig) + return r end end - return false + # Nothing scannable at the *entry* -- no matching methods in the tables (an opaque closure) or + # only unexpandable generated bodies: same policy as an unobtainable entry IR, "could be + # branching". Mid-recursion the leaf treatment stands (`which` resolved the callee; empty or + # `Method`-only results there keep the long-standing give-up tier). + if depth == 0 && !scanned_any + delete!(seen, sig) + return LIMITED + end + return NOBRANCH +end + +# Base's callable wrapper structs (`ComposedFunction`, `Base.Fix`/`Fix1`/`Fix2`) delegate to the +# functions they capture through Base-owned helper methods (kwargs bodies, tuple plumbing), so the +# library-leaf boundary would swallow a user branch hidden inside the wrapper -- e.g. an ODE +# right-hand side written as `relu ∘ layer` silently reported branch-free. Known wrappers are +# unwrapped structurally into component signatures, each routed through the normal call policy: +# a Base component (`sin ∘ f`) stays a library leaf, a user component is scanned. Returns +# `nothing` when `sig` is not a recognized wrapper call. 
+function _unwrap_wrapper(@nospecialize(sig), seen, depth) + sig isa DataType || return nothing + params = collect(sig.parameters) + isempty(params) && return nothing + ft = params[1] + ft isa DataType || return nothing + argts = params[2:end] + if ft <: ComposedFunction && length(ft.parameters) == 2 + O, I = ft.parameters + inner = Tuple{I, argts...} + inner_res = _recurse_sig(inner, nothing, Any[argts...], seen, depth) + inner_res == NOBRANCH || return inner_res + rt = _return_type_of(inner) + return _recurse_sig(Tuple{O, rt}, nothing, Any[rt], seen, depth) + end + if isdefined(Base, :Fix) && ft <: Base.Fix && length(ft.parameters) == 3 + N, F, T = ft.parameters + N isa Int || return nothing + N - 1 <= length(argts) || return nothing + inner = Any[argts...] + insert!(inner, N, T) + return _recurse_sig(Tuple{F, inner...}, nothing, inner, seen, depth) + end + if !isdefined(Base, :Fix) && ft <: Base.Fix1 && length(ft.parameters) == 2 + F, T = ft.parameters + return _recurse_sig(Tuple{F, T, argts...}, nothing, Any[T, argts...], seen, depth) + end + if !isdefined(Base, :Fix) && ft <: Base.Fix2 && length(ft.parameters) == 2 + F, T = ft.parameters + return _recurse_sig(Tuple{F, argts..., T}, nothing, Any[argts..., T], seen, depth) + end + return nothing +end + +function _return_type_of(@nospecialize(sig)) + return try + rs = Base.code_typed_by_type(sig; optimize = false) + isempty(rs) ? Any : _widen(reduce(typejoin, Any[last(pair) for pair in rs])) + catch + Any + end +end + +function _scan_codeinfo(ci, seen, depth) + for stmt in ci.code + if isa(stmt, GotoIfNot) + _is_const_gotoifnot(stmt, ci) || return BRANCH + else + r = _recurse_call(stmt, ci, seen, depth) + r != NOBRANCH && return r + end + end + return NOBRANCH end # A `GotoIfNot` whose condition type inference has *proven* constant is a compile-time branch, # not a value-dependent one: e.g. 
an `x isa T` test on a concretely-typed field (the SciML -# `ODEFunction` wrapper) or the device/type-introspection dispatch inside ML library layers -# (SciML/FunctionProperties.jl#46). Such a branch can never be taken differently under a tracing -# AD, so it is not the branching `hasbranching` is meant to surface. A condition that is a literal -# `true`/`false` written directly into the IR is deliberately *not* skipped: that is a genuine -# syntactic branch in user code (e.g. `true ? a : b`). Only conditions inference resolved to a -# `Core.Const` value are dropped; anything we cannot positively prove constant is kept. +# `ODEFunction` wrapper), the device/type-introspection dispatch inside ML library layers +# (SciML/FunctionProperties.jl#46), or a branch a constant argument folded via the constant-argument +# recursion below. Such a branch can never be taken differently under a tracing AD. A condition that +# is a literal `true`/`false` written directly into the IR is deliberately *not* skipped: that is a +# genuine syntactic branch in user code (e.g. `true ? a : b`). Only conditions inference resolved to +# a `Core.Const` value are dropped; anything we cannot positively prove constant is kept. function _is_const_gotoifnot(stmt::GotoIfNot, ci) cond = stmt.cond t = if cond isa Core.SSAValue @@ -130,8 +232,9 @@ function _is_const_gotoifnot(stmt::GotoIfNot, ci) return t isa Core.Const end -# Inspect a single IR statement: if it is a statically resolvable call into a non-library -# method, recurse into that method's IR. Returns `true` if a branch is found downstream. +# Inspect a single IR statement: if it is a statically resolvable call into a non-library method, +# recurse into that method (with any constant arguments preserved). Returns the tri-state scan +# result: `BRANCH` or `LIMITED` if a branch may exist downstream, otherwise `NOBRANCH`. function _recurse_call(@nospecialize(stmt), ci, seen, depth) call = Meta.isexpr(stmt, :(=)) ? 
stmt.args[2] : stmt @@ -142,66 +245,290 @@ function _recurse_call(@nospecialize(stmt), ci, seen, depth) isdefined(mi, :def) && getfield(mi, :def) isa Core.MethodInstance ? getfield(mi, :def).specTypes : nothing ) - callsig === nothing && return false - return _recurse_sig(callsig, nothing, seen, depth) + callsig === nothing && return NOBRANCH + _, fval = _resolve_callee(call.args[2], ci) + arglat = Any[_arg_lattice(a, ci) for a in @view call.args[3:end]] + return _recurse_sig(callsig, fval, arglat, seen, depth) end - Meta.isexpr(call, :call) || return false + Meta.isexpr(call, :call) || return NOBRANCH if _is_apply(call.args[1]) return _recurse_apply(call, ci, seen, depth) end ftype, fval = _resolve_callee(call.args[1], ci) - ftype === nothing && return false - argtypes = Any[_value_type(a, ci) for a in @view call.args[2:end]] - return _recurse_sig(Tuple{ftype, argtypes...}, fval, seen, depth) + ftype === nothing && return NOBRANCH + arglat = Any[_arg_lattice(a, ci) for a in @view call.args[2:end]] + return _recurse_sig(Tuple{ftype, (_lat_type(x) for x in arglat)...}, fval, arglat, seen, depth) end _is_apply(@nospecialize(f)) = f isa GlobalRef && f.mod === Core && (f.name === :_apply_iterate || f.name === :_apply) -# A splatted call `g(a, bs...)` lowers to `Core._apply_iterate(iter, g, groups...)` (or, on -# older lowerings, `Core._apply(g, groups...)`). The real callee `g` is therefore an *argument* -# of a `Core` builtin, so the plain `:call` path would resolve the callee to `_apply_iterate`, -# treat it as library, and dead-end — missing every branch behind the forwarder. SciML/MTK RHS -# objects are exactly such forwarders (`ODEFunction` -> `GeneratedFunctionWrapper` -> -# `RuntimeGeneratedFunction` -> `generated_callfunc`, each `f(args...)`), so the generated body's -# branches only become reachable by following the apply through to `g`. 
The splatted groups are -# the actual positional arguments; recover their element types from the (concrete) tuple types so -# the right method specialization is selected downstream. +# A splatted call `g(a, bs...)` lowers to `Core._apply_iterate(iter, g, groups...)` (or, on older +# lowerings, `Core._apply(g, groups...)`). The real callee `g` is therefore an *argument* of a +# `Core` builtin, so the plain `:call` path would resolve the callee to `_apply_iterate`, treat it +# as library, and dead-end — missing every branch behind the forwarder. SciML/MTK RHS objects are +# exactly such forwarders (`ODEFunction` -> `GeneratedFunctionWrapper` -> `RuntimeGeneratedFunction` +# -> `generated_callfunc`, each `f(args...)`), so the generated body's branches only become +# reachable by following the apply. The splatted groups are the actual positional arguments; +# recover their element types from the (concrete) tuple types. function _recurse_apply(call, ci, seen, depth) args = call.args fpos = args[1].name === :_apply_iterate ? 3 : 2 - length(args) >= fpos || return false + length(args) >= fpos || return NOBRANCH ftype, fval = _resolve_callee(args[fpos], ci) - ftype === nothing && return false + ftype === nothing && return NOBRANCH argtypes = Any[] for a in @view args[(fpos + 1):end] at = _value_type(a, ci) if at isa DataType && at <: Tuple && Base.isconcretetype(at) append!(argtypes, at.parameters) else - # Splatted container whose element types we cannot recover statically (e.g. a - # non-`isbits` `Vararg` tuple or an array): bail rather than guess a wrong signature. - return false + # Splatted container whose element types we cannot recover statically: bail rather than + # guess a wrong signature. + return NOBRANCH end end - return _recurse_sig(Tuple{ftype, argtypes...}, fval, seen, depth) + # Splatted arguments are recovered from tuple element *types*; constants are not available here. 
+ return _recurse_sig(Tuple{ftype, argtypes...}, fval, Any[argtypes...], seen, depth) end -function _recurse_sig(@nospecialize(callsig), @nospecialize(fval), seen, depth) +function _recurse_sig(@nospecialize(callsig), @nospecialize(fval), arglat, seen, depth) # Honor user `is_leaf` overrides when the concrete function value is recoverable. - fval !== nothing && is_leaf(fval) && return false - # Signature-level overrides: exemptions that depend on the argument types. - is_leaf_sig(callsig) && return false + fval !== nothing && is_leaf(fval) && return NOBRANCH m = try Base.which(callsig) catch - return false + return NOBRANCH + end + # Known Base callable wrappers are unwrapped before the library check would swallow them. + r = _unwrap_wrapper(callsig, seen, depth) + r === nothing || return r + _is_library_method(m) && return NOBRANCH + # Out of depth budget: "could be branching", never "assume a leaf". Refutation manufactures + # depth (one level per constant-recursion step), and a branch-free backstop here let a + # refutation cascade silently fold away a genuine value-dependent branch sitting below the + # cutoff (e.g. at the base of a 400-deep constant-recursion tower). + depth + 1 > RECURSION_LIMIT && return LIMITED + with_consts = depth <= REFUTATION_DEPTH_LIMIT && _const_prop_capable() && + any(x -> x isa Core.Const, arglat) + argtypes = with_consts ? + Any[fval !== nothing ? Core.Const(fval) : _first_param(callsig), arglat...] : nothing + ck = argtypes === nothing ? nothing : _const_key(argtypes) + # Successful refutations are memoized: a refutation that succeeded is path-independent -- the + # cycle marker below can only inject conservative "branch" verdicts, which would have made it + # fail -- so its result is reusable anywhere, and it subsumes the type-level scan. Failed + # refutations are not memoized, since they can be an artifact of the marker on the current + # path. 
Without this memo, constant-recursion towers re-analyze and re-refute the identical + # (sig, constants) on every re-scan, which is quadratic in tower depth. + ck !== nothing && (:refuted, callsig, ck) in seen && return NOBRANCH + # The type recursion is the source of truth. If it finds no branch, we are done. + res = _hasbranching(callsig, seen, depth + 1) + res == NOBRANCH && return NOBRANCH + # Refutation is attempted only for `BRANCH` -- a branch that was actually seen and that the + # constant arguments may decide. A `LIMITED` result cannot be refuted (the refutation's own + # scan would exhaust the same budget and fail), so attempting one only pays for a doomed + # re-descent. Refutation can only downgrade a reported branch to branch-free, never the + # reverse, and it is skipped entirely (leaving the branch reported) when there are no constant + # arguments, when the compiler internals do not cooperate, or when the constant inference + # errors. + if res == BRANCH && ck !== nothing + # A transient path marker breaks refutation cycles: a constant-recursive callee whose + # folded body reaches the same (sig, constants) again must not re-enter refutation (it + # previously recursed until stack overflow, which the error handling then converted into + # "refuted" -- a false negative). Hitting the marker leaves the branch reported. + key = (:refute, callsig, ck) + if !(key in seen) + push!(seen, key) + refuted = try + _const_refutes(callsig, argtypes, seen, depth + 1) + finally + delete!(seen, key) + end + if refuted + push!(seen, (:refuted, callsig, ck)) + return NOBRANCH + end + end + end + return res +end + +# Re-infer `sig` with the constant lattice elements preserved and report whether the result is +# branch-free. The scan shares the caller's `seen`, so proven-branch-free sigs are reused and +# nested refutations bump `depth`, keeping the recursion bounded by `RECURSION_LIMIT`. Returns +# `false` -- i.e. 
does not refute -- whenever the constant inference is unavailable, fails, hits +# the depth budget (`LIMITED` is "could be branching"), or leaves a branch, so an inability to +# fold never suppresses a genuine branch. +function _const_refutes(@nospecialize(sig), argtypes, seen, depth) + depth > RECURSION_LIMIT && return false + src = _const_infer_src(sig, argtypes) + src isa Core.CodeInfo || return false + return try + _scan_codeinfo(src, seen, depth) == NOBRANCH + catch + false + end +end + +# The refute marker must be cheap and total to hash, and must never run user code: constants are +# keyed by `objectid`, which is egal-based (equal isbits values and identical mutables map to the +# same id) -- exact for the marker's purpose. Keying by value hashed with `Base.hash`/`isequal` +# was a stack overflow for self-referential constants, O(length) for large ones, and an uncaught +# user exception for types with throwing `hash`/`==` overloads. +_const_key(argtypes) = map(argtypes) do x + x isa Core.Const ? (true, objectid(x.val)) : (false, x) +end + +_first_param(@nospecialize(sig)) = + (sig isa DataType && !isempty(sig.parameters)) ? sig.parameters[1] : Any +_lat_type(@nospecialize(x)) = x isa Core.Const ? Core.Typeof(x.val) : x + +# Argument lattice element: a `Core.Const` when the argument is a compile-time constant, otherwise +# the widened type. Preserving the `Core.Const` is what lets a constant index survive the recursion +# boundary so `_const_refutes` can fold the branch it decides. +function _arg_lattice(@nospecialize(a), ci) + if a isa Core.SSAValue + t = ci.ssavaluetypes[a.id] + return t isa Core.Const ? t : _widen(t) + elseif a isa Core.Argument + st = ci.slottypes + st === nothing && return Any + t = st[a.n] + return t isa Core.Const ? t : _widen(t) + elseif a isa Core.SlotNumber + st = ci.slottypes + st === nothing && return Any + t = st[a.id] + return t isa Core.Const ? 
t : _widen(t) + elseif a isa GlobalRef + return (isdefined(a.mod, a.name) && isconst(a.mod, a.name)) ? + Core.Const(getglobal(a.mod, a.name)) : Any + elseif a isa QuoteNode + return Core.Const(a.value) + elseif a isa Expr || a isa Core.GotoNode || a isa GotoIfNot || + a isa Core.NewvarNode || a isa Core.ReturnNode + return Any + else + # Raw literal constant embedded in the IR (e.g. an `Int` index). + return Core.Const(a) + end +end + +# ---- constant-argument inference ----------------------------------------------------------- + +# Run inference on `sig` with the given argument lattice (some `Core.Const`) preserved, and return +# the inferred `CodeInfo`, or `nothing` if the compiler internals do not cooperate. The `:no` +# (non-caching) mode is tried first because it skips the optimizer on 1.12 and 1.13, yielding the +# unoptimized inferred body where a constant-decided branch appears as a `GotoIfNot` with a +# `Core.Const` condition and nothing is inlined into view; it also writes nothing into the +# inference cache. `:volatile` is kept as a fallback for compiler versions where `:no` does not +# produce a scannable body -- under it the optimizer may run, which is still sound for refutation +# (inlined library branches can only make refutation fail, i.e. leave the branch reported) but is +# not the preferred shape. The `InferenceState` construction differs across versions (1.13 wants +# the uninferred source passed explicitly), so the explicit-source form is tried before the +# 3-argument form, and the body is read from whichever of `frame.src`/`result.src` is a `CodeInfo`. 
+function _const_infer_src(@nospecialize(sig), argtypes) + m = try + Base.which(sig) + catch + return nothing end - _is_library_method(m) && return false - return _hasbranching(callsig, seen, depth + 1) + mi = try + Base.specialize_method(m, sig, Core.svec()) + catch + return nothing + end + overridden = BitVector(x isa Core.Const for x in argtypes) + src0 = try + _CC.retrieve_code_info(mi, Base.get_world_counter()) + catch + nothing + end + # A fresh `InferenceResult`/`InferenceState` per attempt: an `InferenceResult` cannot be + # re-inferred once used. `copy(src0)` for the same reason -- inference mutates the source. + for cache_mode in (:no, :volatile) + for explicit_src in (true, false) + explicit_src && !(src0 isa Core.CodeInfo) && continue + src = try + interp = _CC.NativeInterpreter() + frame = explicit_src ? + _CC.InferenceState( + _new_result(mi, argtypes, overridden), copy(src0), cache_mode, interp + ) : + _CC.InferenceState(_new_result(mi, argtypes, overridden), cache_mode, interp) + frame === nothing && continue + _CC.typeinf(interp, frame) + _inferred_src(frame) + catch + nothing + end + src isa Core.CodeInfo && return src + end + end + return nothing +end + +_new_result(mi, argtypes, overridden) = _CC.InferenceResult(mi, Any[argtypes...], overridden) + +function _inferred_src(frame) + if isdefined(frame, :src) && getfield(frame, :src) isa Core.CodeInfo + return getfield(frame, :src) + end + r = getfield(frame, :result) + return (r isa _CC.InferenceResult && r.src isa Core.CodeInfo) ? r.src : nothing end +_count_nonconst_gotoifnot(ci::Core.CodeInfo) = + count(s -> isa(s, GotoIfNot) && !_is_const_gotoifnot(s, ci), ci.code) + +# `nothing` until the functional probe has run; then `true`/`false`. +const _CONST_PROP_CAPABLE = Ref{Union{Nothing, Bool}}(nothing) + +# Fixture with a branch decided purely by a constant integer index -- the shape the constant-argument +# recursion must fold. Used only by the capability probe. 
+struct _ProbeContainer + a::Int + b::Int +end +@generated function _probe_indexed(x::_ProbeContainer, idx::Int) + return quote + if idx == 1 + return x.a + else + return x.b + end + end +end + +# Verify, on the running Julia, that constant inference actually folds a constant-decided branch: +# the constant-index call must come back branch-free while the widened-index call must not. If the +# compiler internals we depend on have shifted shape, this returns `false` and the constant-argument +# recursion stays inert (behaviour identical to the plain type recursion). Probed once, then cached. +function _probe_const_prop() + sig = Tuple{typeof(_probe_indexed), _ProbeContainer, Int} + folded = _const_infer_src(sig, Any[Core.Const(_probe_indexed), _ProbeContainer, Core.Const(1)]) + widened = _const_infer_src(sig, Any[Core.Const(_probe_indexed), _ProbeContainer, Int]) + folded isa Core.CodeInfo || return false + widened isa Core.CodeInfo || return false + return _count_nonconst_gotoifnot(folded) == 0 && _count_nonconst_gotoifnot(widened) > 0 +end + +function _const_prop_capable() + v = _CONST_PROP_CAPABLE[] + if v === nothing + v = try + _probe_const_prop() + catch + false + end + _CONST_PROP_CAPABLE[] = v + end + return v +end + +# ---- callee/argument resolution ------------------------------------------------------------ + # Library code (`Base`, `Core`, stdlibs) is treated as a leaf: its branches are structural or # compile-time, not the value-dependent user logic `hasbranching` is meant to surface. function _is_library_method(m::Method) @@ -254,6 +581,6 @@ _widen(@nospecialize t) = t isa Core.PartialStruct ? t.typ : isa(t, Type) ? t : Any -export hasbranching, is_leaf, is_leaf_sig +export hasbranching, is_leaf end diff --git a/test/core_tests.jl b/test/core_tests.jl index 376e831..d7208b0 100644 --- a/test/core_tests.jl +++ b/test/core_tests.jl @@ -144,18 +144,134 @@ splat_forward_free(args...) = splat_target_free(args...) 
@test !FunctionProperties.hasbranching(splat_forward_free, -1.0) # --------------------------------------------------------------------------------------------- -# `is_leaf_sig`: signature-level exemptions for value-independent plumbing. +# Constant-decided branches (value-independent) must not be reported. # # A branch on an integer index that selects a buffer (the MTK `getindex(::MTKParameters, ::Int)` # pattern) is value-independent: each real call site passes a literal index that constant-folds the -# branch, but the recursion only sees the widened `Int` and so reports it. Such a call can be marked -# branch-free by signature. -struct TwoBuffers - a::Float64 - b::Float64 -end -@noinline select_buffer(c::TwoBuffers, i::Int) = i == 1 ? c.a : c.b -rhs_with_plumbing(u, p, t) = select_buffer(p, 1) * u -@test FunctionProperties.hasbranching(rhs_with_plumbing, 1.0, TwoBuffers(1.0, 2.0), 0.0) -FunctionProperties.is_leaf_sig(::Type{<:Tuple{typeof(select_buffer), TwoBuffers, Vararg}}) = true -@test !FunctionProperties.hasbranching(rhs_with_plumbing, 1.0, TwoBuffers(1.0, 2.0), 0.0) +# branch, but ordinary recursion widens the `Int` and so reports it. The constant-argument recursion +# re-infers the callee with the constant preserved so the branch folds away — where the running +# Julia's compiler cooperates (`_const_prop_capable()`). It stays conservative: a genuinely +# value-dependent branch, and a dynamic (non-constant) index, are always reported. +struct TwoBufferParams + a::Vector{Float64} + b::Vector{Float64} +end +@generated function pick_buffer(p::TwoBufferParams, idx::Int) + return quote + if idx == 1 + return p.a + elseif idx == 2 + return p.b + else + throw(BoundsError(p, idx)) + end + end +end +cp_relu(x) = x > 0 ? 
x : zero(x) +rhs_const_index(p) = @inbounds pick_buffer(p, 1)[1] +rhs_dynamic_index(p, i) = @inbounds pick_buffer(p, i)[1] +rhs_real_branch(u, p) = cp_relu(u) + @inbounds pick_buffer(p, 1)[1] +tbp = TwoBufferParams([1.0], [2.0]) + +@test FunctionProperties.hasbranching(rhs_real_branch, 1.0, tbp) # genuine branch: always reported +@test FunctionProperties.hasbranching(rhs_dynamic_index, tbp, 1) # dynamic index: always reported +if FunctionProperties._const_prop_capable() + @test !FunctionProperties.hasbranching(rhs_const_index, tbp) # constant index folds away +end + +# Refutation must be per-call-site: the same widened callsig can be refutable at one call site +# (constant index) and not at another (dynamic index). Memoizing the sig on the true-returning +# chain produced order-dependent false negatives (const-site first suppressed the dynamic site). +rhs_mixed_cd(p, i) = pick_buffer(p, 1)[1] + pick_buffer(p, i)[1] +rhs_mixed_dc(p, i) = pick_buffer(p, i)[1] + pick_buffer(p, 1)[1] +@test FunctionProperties.hasbranching(rhs_mixed_cd, tbp, 2) +@test FunctionProperties.hasbranching(rhs_mixed_dc, tbp, 2) + +# A constant-recursive callee must not send the refutation into unbounded recursion (previously a +# stack overflow inside inference, which the error handling then converted into a false negative). +# The refutation cycle is broken conservatively, so the branch stays reported -- and quickly. +recur_const(p, n) = n == 0 ? p.a : recur_const(p, 5) +rhs_recur(p) = recur_const(p, 5)[1] +mutual_a(p, n) = n == 0 ? p.a : mutual_b(p, 4) +mutual_b(p, n) = n == 1 ? 
p.b : mutual_a(p, 3) +rhs_mutual(p) = mutual_a(p, 5)[1] +@test FunctionProperties.hasbranching(rhs_recur, tbp) +@test FunctionProperties.hasbranching(rhs_mutual, tbp) + +# A value-dependent branch at the base of a constant-recursion tower must never be lost, even when +# the tower exceeds the refutation depth budget: exhausting the budget fails refutation ("cannot +# verify" reports the branch) rather than assuming a leaf. Previously the depth backstop returned +# branch-free, which a refutation cascade silently propagated into a false negative. +hidden_base(p, n, x) = n == 0 ? (x > 0 ? p.a : p.b) : hidden_base(p, n - 1, x) +rhs_hidden5(p, x) = hidden_base(p, 5, x)[1] +rhs_hidden400(p, x) = hidden_base(p, 400, x)[1] +@test FunctionProperties.hasbranching(rhs_hidden5, tbp, 0.5) +@test FunctionProperties.hasbranching(rhs_hidden400, tbp, 0.5) + +# Constant-decided recursion folds fully below the depth budget, and is conservatively reported +# above it. A diverging constant recursion (`n + 1`) must terminate via the depth budget. +cnt_const(p, n) = n == 0 ? p.a : cnt_const(p, n - 1) +rhs_cnt5(p) = cnt_const(p, 5)[1] +rhs_cnt400(p) = cnt_const(p, 400)[1] +asc_const(p, n) = n == 0 ? p.a : asc_const(p, n + 1) +rhs_asc(p) = asc_const(p, 5)[1] +if FunctionProperties._const_prop_capable() + @test !FunctionProperties.hasbranching(rhs_cnt5, tbp) +end +@test FunctionProperties.hasbranching(rhs_cnt400, tbp) +@test FunctionProperties.hasbranching(rhs_asc, tbp) + +# The refutation path marker keys non-isbits constants by object identity: hashing the value +# itself stack-overflowed on self-referential constants (uncaught, escaping `hasbranching`) and +# was O(length) on large ones. The exact result is version-dependent (whether inference folds +# predicates on a mutable constant); the invariant is that the query completes. +const SELFREF_CONST = Any[] +push!(SELFREF_CONST, SELFREF_CONST) +selref_pick(p, v) = isempty(v) ? 
p.a : p.b +rhs_selfref(p) = selref_pick(p, SELFREF_CONST)[1] +@test FunctionProperties.hasbranching(rhs_selfref, tbp) isa Bool + +# `hasbranching` consulted from inside a generated-function expansion: reflection is restricted +# there, so the IR may be unobtainable -- the answer must then be the conservative "could be +# branching", not a silent branch-free (which made generators emit the wrong arm). +genhb_branchy(x) = x > 0 ? x : -x +@generated function gen_consults_hb(p) + return FunctionProperties.hasbranching(genhb_branchy, 1.0) ? :(p.a) : :(p.b) +end +@test gen_consults_hb(tbp) == tbp.a + +# Constants are keyed by `objectid` in the refutation machinery: user `hash`/`==` overloads must +# never run (a throwing overload previously escaped `hasbranching` as an uncaught exception). +struct EvilHashBits + x::Int +end +Base.hash(::EvilHashBits, ::UInt) = error("user hash must not be called") +Base.:(==)(::EvilHashBits, ::EvilHashBits) = error("user == must not be called") +evil_pick(p, e) = e.x == 1 ? p.a : p.b +rhs_evilhash(p) = evil_pick(p, EvilHashBits(1))[1] +@test FunctionProperties.hasbranching(rhs_evilhash, tbp) isa Bool + +# Loads from const-bound MUTABLES must stay unfolded (Julia's effects system guarantees this; +# lock it in): folding on current contents would turn into a false negative after mutation. +const MUT_FLAG = Ref(true) +mut_pick(p) = MUT_FLAG[] ? p.a : p.b +@test FunctionProperties.hasbranching(mut_pick, tbp) + +# An entry with nothing scannable -- e.g. an opaque closure, which has no entry in the method +# tables -- must answer the conservative "could be branching", not a silent branch-free. +const OC_BRANCHY = Base.Experimental.@opaque p -> p.a[1] > 0 ? p.a : p.b +@test FunctionProperties.hasbranching(OC_BRANCHY, tbp) + +# Base's callable wrappers delegate through Base-owned helpers, so the library boundary hid user +# branches inside them (`relu ∘ layer` reported branch-free). 
Known wrappers are unwrapped into +# component signatures under the normal policy: user components are scanned, Base components stay +# library leaves. +wrap_branchy(x) = x > 0 ? x : zero(x) +wrap_cmp(x, t) = x > t ? x : t +@test FunctionProperties.hasbranching(wrap_branchy ∘ identity, 1.0) +@test FunctionProperties.hasbranching(identity ∘ wrap_branchy, 1.0) +@test !FunctionProperties.hasbranching(abs2 ∘ identity, 1.0) +@test !FunctionProperties.hasbranching(sin ∘ identity, 1.0) # Base components stay leaves +@test FunctionProperties.hasbranching(Base.Fix1(wrap_cmp, 0.0), 1.0) +@test FunctionProperties.hasbranching(Base.Fix2(wrap_cmp, 0.0), 1.0) +@test !FunctionProperties.hasbranching(Base.Fix2(*, 2.0), 1.0) diff --git a/test/qa/qa.jl b/test/qa/qa.jl index a4190c4..2447108 100644 --- a/test/qa/qa.jl +++ b/test/qa/qa.jl @@ -1,15 +1,26 @@ using SciMLTesting, FunctionProperties, JET, Test # `hasbranching` is a compiler-introspection utility: it `code_typed`s `f` and scans the -# resulting typed IR for value-dependent branches, so it necessarily reaches into the -# `Core`/`Base` IR and inference internals, none of which have a public equivalent: +# resulting typed IR for value-dependent branches, and -- to fold branches that constant +# arguments decide -- re-runs inference with `Core.Const` argument lattices preserved. It +# therefore necessarily reaches into the `Core`/`Base` IR and inference internals, none of +# which have a public equivalent: # - `GotoIfNot` (explicit import via `using Core: GotoIfNot`) is the conditional-branch IR node. -# - `CodeInfo`/`SSAValue`/`SlotNumber`/`Argument` are typed-IR node types scanned in the body. +# - `CodeInfo`/`SSAValue`/`SlotNumber`/`Argument`/`GotoNode`/`NewvarNode`/`ReturnNode` are +# typed-IR node types scanned in the body. # - `Const`/`PartialStruct` are inference lattice element types read off the IR. -# - `MethodInstance` is the resolved-call type used to recurse through static calls. 
+# - `MethodInstance` is the resolved-call type used to recurse through static calls, and +# `svec` builds the empty sparam vector for `specialize_method`. # - `Typeof` builds the dispatch signature (`typeof` differs from `Core.Typeof` on # type-valued arguments, and `Base.typesof` is itself non-public). -# - `code_typed_by_type` is the non-public typed-IR entry point (`Base.code_typed_by_type`). +# - `code_typed_by_type` is the non-public typed-IR entry point (`Base.code_typed_by_type`), +# and `specialize_method`/`get_world_counter` are the reflection pieces needed to build a +# method instance for constant re-inference. +# - `Compiler`/`NativeInterpreter`/`InferenceResult`/`InferenceState`/`typeinf`/ +# `retrieve_code_info` are the abstract interpreter: there is no public API for "infer this +# method with constant argument types". This dependency is deliberately confined behind a +# functional capability probe (`_const_prop_capable`) so the package degrades to the plain +# type scan wherever these internals change shape. # All of these are Core/Base compiler-introspection internals with no public API, so they are # ignored in the public-API checks. run_qa( @@ -19,8 +30,11 @@ run_qa( all_explicit_imports_are_public = (; ignore = (:GotoIfNot,)), all_qualified_accesses_are_public = (; ignore = ( - :Typeof, :Argument, :CodeInfo, :Const, :MethodInstance, - :PartialStruct, :SSAValue, :SlotNumber, :code_typed_by_type, + :Typeof, :Argument, :CodeInfo, :Compiler, :Const, :GotoNode, + :InferenceResult, :InferenceState, :MethodInstance, :NativeInterpreter, + :NewvarNode, :PartialStruct, :ReturnNode, :SSAValue, :SlotNumber, + :code_typed_by_type, :get_world_counter, :retrieve_code_info, + :specialize_method, :svec, :typeinf, ), ), )