Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FunctionProperties"
uuid = "f62d2435-5019-4c03-9749-2d4c77af0cbc"
version = "0.1.7"
version = "0.1.8"
authors = ["SciML"]

[deps]
Expand Down
228 changes: 216 additions & 12 deletions src/FunctionProperties.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,85 @@ using Core: GotoIfNot
# Backstop against pathological recursion depth; real call trees that matter here are shallow.
const RECURSION_LIMIT = 256

# ---- experimental: constant-propagation-aware recursion ------------------------------------
#
# When recursing into a call, ordinary analysis widens every argument to its type, discarding
# constants. A branch decided by a constant argument (e.g. selecting a buffer by a literal index
# inside a parameter container) therefore looks value-dependent even though every real call site
# folds it. If we instead preserve the `Core.Const` argument and re-run *inference* (no optimizer,
# so no library/structural branches are inlined into view), such branches fold to `Core.Const`
# conditions that [`_is_const_gotoifnot`](@ref) already skips -- generalizing [`is_leaf_sig`](@ref)
# without any per-container knowledge.
#
# This relies on `Base.Compiler` (`Core.Compiler`) internals whose API churns across Julia versions
# (the `InferenceState` construction and inferred-source extraction already differ between 1.12 and
# 1.13), so it is *functionally* capability-gated -- see `_const_prop_capable` -- and OFF by
# default. Enable with [`enable_const_prop!`](@ref).
const _CC = isdefined(Base, :Compiler) ? Base.Compiler : Core.Compiler

const _CONST_PROP = Ref(false)
# `nothing` until the functional probe has run; then `true`/`false`.
const _CONST_PROP_CAPABLE = Ref{Union{Nothing, Bool}}(nothing)

# Fixture with a branch decided purely by a constant integer index -- the shape the feature must
# fold. Used only by the capability probe.
struct _ProbeContainer
a::Int
b::Int
end
@generated function _probe_indexed(x::_ProbeContainer, idx::Int)
quote
if idx == 1
return x.a
else
return x.b
end
end
end

# Verify, on the running Julia, that constant inference actually folds a constant-decided branch:
# the constant-index call must come back branch-free while the widened-index call must not. If the
# compiler internals we depend on have shifted shape, this returns `false` and the feature stays
# inert (so behaviour is identical to the plain type recursion). Probed once, then cached.
function _probe_const_prop()
sig = Tuple{typeof(_probe_indexed), _ProbeContainer, Int}
folded = _const_infer_src(sig, Any[Core.Const(_probe_indexed), _ProbeContainer, Core.Const(1)])
widened = _const_infer_src(sig, Any[Core.Const(_probe_indexed), _ProbeContainer, Int])
folded isa Core.CodeInfo || return false
widened isa Core.CodeInfo || return false
return _count_nonconst_gotoifnot(folded) == 0 && _count_nonconst_gotoifnot(widened) > 0
end

function _const_prop_capable()
v = _CONST_PROP_CAPABLE[]
if v === nothing
v = try
_probe_const_prop()
catch
false
end
_CONST_PROP_CAPABLE[] = v
end
return v
end

"""
enable_const_prop!(on::Bool = true) -> Bool

Experimental. Toggle constant-propagation-aware recursion in [`hasbranching`](@ref). When on (and
the running Julia's compiler internals still fold a constant-decided branch, verified functionally
by [`_const_prop_capable`](@ref)), a call with constant arguments is re-inferred with those
constants preserved, so value-independent branches decided by a constant (e.g. selecting a buffer
by a literal index) fold away instead of being reported. Off by default because it depends on
compiler internals. Returns the effective state.
"""
function enable_const_prop!(on::Bool = true)
_CONST_PROP[] = on
return _const_prop_active()
end

_const_prop_active() = _CONST_PROP[] && _const_prop_capable()

"""
is_leaf(f, args...) -> Bool

Expand Down Expand Up @@ -96,17 +175,100 @@ function _hasbranching(@nospecialize(sig), seen, depth)
# Generated functions that were not expanded come back as `Method`, not `CodeInfo`;
# there is no body to scan, so treat them as leaves.
ci isa Core.CodeInfo || continue
for stmt in ci.code
if isa(stmt, GotoIfNot)
_is_const_gotoifnot(stmt, ci) || return true
elseif _recurse_call(stmt, ci, seen, depth)
return true
end
_scan_codeinfo(ci, seen, depth) && return true
end
return false
end

function _scan_codeinfo(ci, seen, depth)
for stmt in ci.code
if isa(stmt, GotoIfNot)
_is_const_gotoifnot(stmt, ci) || return true
elseif _recurse_call(stmt, ci, seen, depth)
return true
end
end
return false
end

# Constant-argument recursion (experimental, see [`enable_const_prop!`](@ref)): re-infer the callee
# with the constant lattice elements preserved so branches decided by a constant argument fold to
# `Core.Const` conditions. Inference is run *without* the optimizer, so no library/structural
# branches are inlined into view. Falls back to the plain type recursion whenever the constant
# inference is unavailable or fails.
function _hasbranching_const(@nospecialize(sig), argtypes, seen, depth)
depth > RECURSION_LIMIT && return false
key = (sig, _const_key(argtypes))
key in seen && return false
push!(seen, key)
src = _const_infer_src(sig, argtypes)
src isa Core.CodeInfo || return _hasbranching(sig, seen, depth)
return _scan_codeinfo(src, seen, depth)
end

_const_key(argtypes) = map(x -> x isa Core.Const ? (true, x.val) : (false, x), argtypes)

# Run inference on `sig` with the given argument lattice (some `Core.Const`) preserved, and return
# the inferred (unoptimized) `CodeInfo`, or `nothing` if the compiler internals do not cooperate.
# The `InferenceState` construction and the inferred-source location differ across Julia versions:
# 1.12 accepts `InferenceState(result, cache_mode, interp)` and exposes the body on `result.src`,
# while 1.13 wants the uninferred source passed explicitly and exposes the body on `frame.src`. We
# try the explicit-source form first (works on both) with the non-caching `:volatile` mode, then
# fall back, and read whichever of `frame.src`/`result.src` is a `CodeInfo`. Any shape we don't
# recognise simply yields `nothing`, and the functional probe (`_const_prop_capable`) keeps the
# whole feature inert on such versions.
function _const_infer_src(@nospecialize(sig), argtypes)
m = try
Base.which(sig)
catch
return nothing
end
mi = try
Base.specialize_method(m, sig, Core.svec())
catch
return nothing
end
overridden = BitVector(x isa Core.Const for x in argtypes)
src0 = try
_CC.retrieve_code_info(mi, Base.get_world_counter())
catch
nothing
end
# A fresh `InferenceResult`/`InferenceState` per attempt: an `InferenceResult` cannot be
# re-inferred once used.
for build in (
interp -> src0 isa Core.CodeInfo ?
_CC.InferenceState(_new_result(mi, argtypes, overridden), src0, :volatile, interp) :
nothing,
interp -> _CC.InferenceState(_new_result(mi, argtypes, overridden), :volatile, interp),
)
src = try
interp = _CC.NativeInterpreter()
frame = build(interp)
frame === nothing && continue
_CC.typeinf(interp, frame)
_inferred_src(frame)
catch
nothing
end
src isa Core.CodeInfo && return src
end
return nothing
end

_new_result(mi, argtypes, overridden) = _CC.InferenceResult(mi, Any[argtypes...], overridden)

function _inferred_src(frame)
if isdefined(frame, :src) && getfield(frame, :src) isa Core.CodeInfo
return getfield(frame, :src)
end
r = getfield(frame, :result)
return (r isa _CC.InferenceResult && r.src isa Core.CodeInfo) ? r.src : nothing
end

_count_nonconst_gotoifnot(ci::Core.CodeInfo) =
count(s -> isa(s, GotoIfNot) && !_is_const_gotoifnot(s, ci), ci.code)

# A `GotoIfNot` whose condition type inference has *proven* constant is a compile-time branch,
# not a value-dependent one: e.g. an `x isa T` test on a concretely-typed field (the SciML
# `ODEFunction` wrapper) or the device/type-introspection dispatch inside ML library layers
Expand Down Expand Up @@ -143,7 +305,9 @@ function _recurse_call(@nospecialize(stmt), ci, seen, depth)
getfield(mi, :def).specTypes : nothing
)
callsig === nothing && return false
return _recurse_sig(callsig, nothing, seen, depth)
_, fval = _resolve_callee(call.args[2], ci)
arglat = Any[_arg_lattice(a, ci) for a in @view call.args[3:end]]
return _recurse_sig(callsig, fval, arglat, seen, depth)
end

Meta.isexpr(call, :call) || return false
Expand All @@ -152,8 +316,8 @@ function _recurse_call(@nospecialize(stmt), ci, seen, depth)
end
ftype, fval = _resolve_callee(call.args[1], ci)
ftype === nothing && return false
argtypes = Any[_value_type(a, ci) for a in @view call.args[2:end]]
return _recurse_sig(Tuple{ftype, argtypes...}, fval, seen, depth)
arglat = Any[_arg_lattice(a, ci) for a in @view call.args[2:end]]
return _recurse_sig(Tuple{ftype, (_lat_type(x) for x in arglat)...}, fval, arglat, seen, depth)
end

_is_apply(@nospecialize(f)) =
Expand Down Expand Up @@ -185,10 +349,11 @@ function _recurse_apply(call, ci, seen, depth)
return false
end
end
return _recurse_sig(Tuple{ftype, argtypes...}, fval, seen, depth)
# Splatted arguments are recovered from tuple element *types*; constants are not available here.
return _recurse_sig(Tuple{ftype, argtypes...}, fval, Any[argtypes...], seen, depth)
end

function _recurse_sig(@nospecialize(callsig), @nospecialize(fval), seen, depth)
function _recurse_sig(@nospecialize(callsig), @nospecialize(fval), arglat, seen, depth)
# Honor user `is_leaf` overrides when the concrete function value is recoverable.
fval !== nothing && is_leaf(fval) && return false
# Signature-level overrides: exemptions that depend on the argument types.
Expand All @@ -199,9 +364,48 @@ function _recurse_sig(@nospecialize(callsig), @nospecialize(fval), seen, depth)
return false
end
_is_library_method(m) && return false
if _const_prop_active() && any(x -> x isa Core.Const, arglat)
funclat = fval !== nothing ? Core.Const(fval) : _first_param(callsig)
return _hasbranching_const(callsig, Any[funclat, arglat...], seen, depth + 1)
end
return _hasbranching(callsig, seen, depth + 1)
end

_first_param(@nospecialize(sig)) =
(sig isa DataType && !isempty(sig.parameters)) ? sig.parameters[1] : Any
_lat_type(@nospecialize(x)) = x isa Core.Const ? Core.Typeof(x.val) : x

# Argument lattice element: a `Core.Const` when the argument is a compile-time constant, otherwise
# the widened type. Preserving the `Core.Const` is what lets a constant index survive the recursion
# boundary so `_hasbranching_const` can fold the branch it decides.
function _arg_lattice(@nospecialize(a), ci)
if a isa Core.SSAValue
t = ci.ssavaluetypes[a.id]
return t isa Core.Const ? t : _widen(t)
elseif a isa Core.Argument
st = ci.slottypes
st === nothing && return Any
t = st[a.n]
return t isa Core.Const ? t : _widen(t)
elseif a isa Core.SlotNumber
st = ci.slottypes
st === nothing && return Any
t = st[a.id]
return t isa Core.Const ? t : _widen(t)
elseif a isa GlobalRef
return (isdefined(a.mod, a.name) && isconst(a.mod, a.name)) ?
Core.Const(getglobal(a.mod, a.name)) : Any
elseif a isa QuoteNode
return Core.Const(a.value)
elseif a isa Expr || a isa Core.GotoNode || a isa GotoIfNot ||
a isa Core.NewvarNode || a isa Core.ReturnNode
return Any
else
# Raw literal constant embedded in the IR (e.g. an `Int` index).
return Core.Const(a)
end
end

# Library code (`Base`, `Core`, stdlibs) is treated as a leaf: its branches are structural or
# compile-time, not the value-dependent user logic `hasbranching` is meant to surface.
function _is_library_method(m::Method)
Expand Down Expand Up @@ -254,6 +458,6 @@ _widen(@nospecialize t) =
t isa Core.PartialStruct ? t.typ :
isa(t, Type) ? t : Any

export hasbranching, is_leaf, is_leaf_sig
export hasbranching, is_leaf, is_leaf_sig, enable_const_prop!

end
44 changes: 44 additions & 0 deletions test/core_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,47 @@ rhs_with_plumbing(u, p, t) = select_buffer(p, 1) * u
@test FunctionProperties.hasbranching(rhs_with_plumbing, 1.0, TwoBuffers(1.0, 2.0), 0.0)
FunctionProperties.is_leaf_sig(::Type{<:Tuple{typeof(select_buffer), TwoBuffers, Vararg}}) = true
@test !FunctionProperties.hasbranching(rhs_with_plumbing, 1.0, TwoBuffers(1.0, 2.0), 0.0)

# ---------------------------------------------------------------------------------------------
# Experimental: constant-propagation-aware recursion (`enable_const_prop!`).
#
# A branch decided by a *constant* argument (e.g. selecting a buffer by a literal index) is
# value-independent, but ordinary recursion widens the argument and reports it. With const-prop on,
# the callee is re-inferred with the constant preserved so such branches fold away. It stays
# conservative: a genuinely value-dependent branch, or a dynamic (non-constant) index, is still
# reported. Off by default, so it must not change any behavior unless explicitly enabled.
struct TwoBufferParams
a::Vector{Float64}
b::Vector{Float64}
end
@generated function pick_buffer(p::TwoBufferParams, idx::Int)
quote
if idx == 1
return p.a
elseif idx == 2
return p.b
else
throw(BoundsError(p, idx))
end
end
end
cp_relu(x) = x > 0 ? x : zero(x)
rhs_const_index(p) = @inbounds pick_buffer(p, 1)[1]
rhs_dynamic_index(p, i) = @inbounds pick_buffer(p, i)[1]
rhs_real_branch(u, p) = cp_relu(u) + @inbounds pick_buffer(p, 1)[1]
tbp = TwoBufferParams([1.0], [2.0])

# Default: off -> the value-independent index branch is (conservatively) still reported.
@test !FunctionProperties.enable_const_prop!(false)
@test FunctionProperties.hasbranching(rhs_const_index, tbp)

if FunctionProperties._const_prop_capable()
@test FunctionProperties.enable_const_prop!(true)
try
@test !FunctionProperties.hasbranching(rhs_const_index, tbp) # constant index folds
@test FunctionProperties.hasbranching(rhs_real_branch, 1.0, tbp) # genuine branch kept
@test FunctionProperties.hasbranching(rhs_dynamic_index, tbp, 1) # dynamic index: conservative
finally
FunctionProperties.enable_const_prop!(false)
end
end
Loading