Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FunctionProperties"
uuid = "f62d2435-5019-4c03-9749-2d4c77af0cbc"
version = "0.1.6"
version = "0.1.7"
authors = ["SciML"]

[deps]
Expand Down
90 changes: 87 additions & 3 deletions src/FunctionProperties.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,27 @@ FunctionProperties.is_leaf(::typeof(my_fn)) = true
"""
is_leaf(f, args...) = false

"""
is_leaf_sig(sig::Type{<:Tuple}) -> Bool

Signature-level counterpart to [`is_leaf`](@ref), consulted while recursing through statically
resolved calls. `sig` is the call's `Tuple{typeof(f), argtypes...}`. Return `true` to treat the
call as branch-free and stop recursing into it.

Use this (instead of `is_leaf`) when the exemption depends on the *argument types*, not just the
function. The motivating case is value-independent plumbing whose branch is on an index/type
rather than on traced values — e.g. selecting a buffer by integer index inside a parameter
container, where each real call site passes a literal index that constant-folds the branch away,
but the recursion only sees the widened argument type.

## Example

```julia
FunctionProperties.is_leaf_sig(::Type{<:Tuple{typeof(getindex), <:MyParamContainer, Vararg}}) = true
```
"""
is_leaf_sig(@nospecialize(sig)) = false

"""
hasbranching(f, x...)

Expand Down Expand Up @@ -75,14 +96,40 @@ function _hasbranching(@nospecialize(sig), seen, depth)
# Generated functions that were not expanded come back as `Method`, not `CodeInfo`;
# there is no body to scan, so treat them as leaves.
ci isa Core.CodeInfo || continue
any(stmt -> isa(stmt, GotoIfNot), ci.code) && return true
for stmt in ci.code
_recurse_call(stmt, ci, seen, depth) && return true
if isa(stmt, GotoIfNot)
_is_const_gotoifnot(stmt, ci) || return true
elseif _recurse_call(stmt, ci, seen, depth)
return true
end
end
end
return false
end

# A `GotoIfNot` whose condition type inference has *proven* constant is a compile-time branch,
# not a value-dependent one: e.g. an `x isa T` test on a concretely-typed field (the SciML
# `ODEFunction` wrapper) or the device/type-introspection dispatch inside ML library layers
# (SciML/FunctionProperties.jl#46). Such a branch can never be taken differently under a tracing
# AD, so it is not the branching `hasbranching` is meant to surface. A condition that is a literal
# `true`/`false` written directly into the IR is deliberately *not* skipped: that is a genuine
# syntactic branch in user code (e.g. `true ? a : b`). Only conditions inference resolved to a
# `Core.Const` value are dropped; anything we cannot positively prove constant is kept.
function _is_const_gotoifnot(stmt::GotoIfNot, ci)
cond = stmt.cond
t = if cond isa Core.SSAValue
types = ci.ssavaluetypes
types isa AbstractVector && checkbounds(Bool, types, cond.id) ? types[cond.id] : nothing
elseif cond isa Core.Argument
ci.slottypes === nothing ? nothing : get(ci.slottypes, cond.n, nothing)
elseif cond isa Core.SlotNumber
ci.slottypes === nothing ? nothing : get(ci.slottypes, cond.id, nothing)
else
nothing
end
return t isa Core.Const
end

# Inspect a single IR statement: if it is a statically resolvable call into a non-library
# method, recurse into that method's IR. Returns `true` if a branch is found downstream.
function _recurse_call(@nospecialize(stmt), ci, seen, depth)
Expand All @@ -100,15 +147,52 @@ function _recurse_call(@nospecialize(stmt), ci, seen, depth)
end

Meta.isexpr(call, :call) || return false
if _is_apply(call.args[1])
return _recurse_apply(call, ci, seen, depth)
end
ftype, fval = _resolve_callee(call.args[1], ci)
ftype === nothing && return false
argtypes = Any[_value_type(a, ci) for a in @view call.args[2:end]]
return _recurse_sig(Tuple{ftype, argtypes...}, fval, seen, depth)
end

_is_apply(@nospecialize(f)) =
f isa GlobalRef && f.mod === Core && (f.name === :_apply_iterate || f.name === :_apply)

# A splatted call `g(a, bs...)` lowers to `Core._apply_iterate(iter, g, groups...)` (or, on
# older lowerings, `Core._apply(g, groups...)`). The real callee `g` is therefore an *argument*
# of a `Core` builtin, so the plain `:call` path would resolve the callee to `_apply_iterate`,
# treat it as library, and dead-end — missing every branch behind the forwarder. SciML/MTK RHS
# objects are exactly such forwarders (`ODEFunction` -> `GeneratedFunctionWrapper` ->
# `RuntimeGeneratedFunction` -> `generated_callfunc`, each `f(args...)`), so the generated body's
# branches only become reachable by following the apply through to `g`. The splatted groups are
# the actual positional arguments; recover their element types from the (concrete) tuple types so
# the right method specialization is selected downstream.
function _recurse_apply(call, ci, seen, depth)
args = call.args
fpos = args[1].name === :_apply_iterate ? 3 : 2
length(args) >= fpos || return false
ftype, fval = _resolve_callee(args[fpos], ci)
ftype === nothing && return false
argtypes = Any[]
for a in @view args[(fpos + 1):end]
at = _value_type(a, ci)
if at isa DataType && at <: Tuple && Base.isconcretetype(at)
append!(argtypes, at.parameters)
else
# Splatted container whose element types we cannot recover statically (e.g. a
# non-`isbits` `Vararg` tuple or an array): bail rather than guess a wrong signature.
return false
end
end
return _recurse_sig(Tuple{ftype, argtypes...}, fval, seen, depth)
end

function _recurse_sig(@nospecialize(callsig), @nospecialize(fval), seen, depth)
# Honor user `is_leaf` overrides when the concrete function value is recoverable.
fval !== nothing && is_leaf(fval) && return false
# Signature-level overrides: exemptions that depend on the argument types.
is_leaf_sig(callsig) && return false
m = try
Base.which(callsig)
catch
Expand Down Expand Up @@ -170,6 +254,6 @@ _widen(@nospecialize t) =
t isa Core.PartialStruct ? t.typ :
isa(t, Type) ? t : Any

export hasbranching, is_leaf
export hasbranching, is_leaf, is_leaf_sig

end
57 changes: 57 additions & 0 deletions test/core_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,60 @@ end
x0 = [-4.0f0, 0.0f0]
ts = Float32.(collect(0.0:0.01:tspan[2]))
@test !FunctionProperties.hasbranching(dxdt_, copy(x0), x0, p, tspan[1])

# ---------------------------------------------------------------------------------------------
# Value-independent (compile-time-constant) branches must not be reported.
#
# A `GotoIfNot` whose condition inference proves `Core.Const` cannot be taken differently under a
# tracing AD, so it is wrapper/dispatch plumbing rather than the value-dependent branching
# `hasbranching` is meant to surface. This is the shape of the SciML `ODEFunction` functor
# (`if f.f isa AbstractSciMLOperator`) and of ML-library device/type-introspection dispatch
# (SciML/FunctionProperties.jl#46). A *literal* `true`/`false` condition is still a genuine branch
# and is kept (covered by the `f_branch` test above).
abstract type FakeOperator end
struct CondWrap{F}
f::F
end
function (w::CondWrap)(x)
if w.f isa FakeOperator # concretely-typed field => `isa` folds to a constant
return zero(x)
else
return w.f(x)
end
end
branchfree_inner(x) = x * x + one(x)
branchy_inner(x) = x < 0 ? -x : x
@test !FunctionProperties.hasbranching(CondWrap(branchfree_inner), 1.0) # const `isa` skipped
@test FunctionProperties.hasbranching(CondWrap(branchy_inner), 1.0) # real inner branch kept

# ---------------------------------------------------------------------------------------------
# Branches behind a *splatted* call boundary must be detected.
#
# `g(args...)` lowers to `Core._apply_iterate(iter, g, args)`, hiding the real callee `g` as an
# argument of a `Core` builtin. The scan must follow the apply through to `g`, otherwise every
# branch behind a splat forwarder is missed. This is the SciML/MTK RHS shape (`ODEFunction` ->
# `GeneratedFunctionWrapper` -> `RuntimeGeneratedFunction` -> `generated_callfunc`, each a
# `f(args...)` forwarder).
@noinline splat_target_branchy(x) = x < 0 ? -x : x
@noinline splat_target_free(x) = x * x
splat_forward_branchy(args...) = splat_target_branchy(args...)
splat_forward_free(args...) = splat_target_free(args...)
@test FunctionProperties.hasbranching(splat_forward_branchy, -1.0)
@test !FunctionProperties.hasbranching(splat_forward_free, -1.0)

# ---------------------------------------------------------------------------------------------
# `is_leaf_sig`: signature-level exemptions for value-independent plumbing.
#
# A branch on an integer index that selects a buffer (the MTK `getindex(::MTKParameters, ::Int)`
# pattern) is value-independent: each real call site passes a literal index that constant-folds the
# branch, but the recursion only sees the widened `Int` and so reports it. Such a call can be marked
# branch-free by signature.
struct TwoBuffers
a::Float64
b::Float64
end
@noinline select_buffer(c::TwoBuffers, i::Int) = i == 1 ? c.a : c.b
rhs_with_plumbing(u, p, t) = select_buffer(p, 1) * u
@test FunctionProperties.hasbranching(rhs_with_plumbing, 1.0, TwoBuffers(1.0, 2.0), 0.0)
FunctionProperties.is_leaf_sig(::Type{<:Tuple{typeof(select_buffer), TwoBuffers, Vararg}}) = true
@test !FunctionProperties.hasbranching(rhs_with_plumbing, 1.0, TwoBuffers(1.0, 2.0), 0.0)
Loading