Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FunctionProperties"
uuid = "f62d2435-5019-4c03-9749-2d4c77af0cbc"
version = "1.0.0"
version = "1.1.0"
authors = ["SciML"]

[deps]
Expand Down
2 changes: 2 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,7 @@

```@docs
hasbranching
islinear
isquadratic
is_leaf
```
7 changes: 7 additions & 0 deletions docs/src/assets/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FunctionProperties = "f62d2435-5019-4c03-9749-2d4c77af0cbc"

[compat]
Documenter = "1"
FunctionProperties = "1"
585 changes: 7 additions & 578 deletions src/FunctionProperties.jl

Large diffs are not rendered by default.

578 changes: 578 additions & 0 deletions src/hasbranching.jl

Large diffs are not rendered by default.

174 changes: 174 additions & 0 deletions src/polydegree.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
# Certification by abstract interpretation over polynomial degrees, in the style of tracer-type
# sparsity detection (Gowda et al., "Sparsity Programming", NeurIPS 2019 program-transformations
# workshop): the tracked arguments are seeded with degree-1 tracer numbers and the program is
# executed on them, propagating a sound upper bound on the real-arithmetic polynomial degree.
# Non-polynomial operations on non-constant values poison the result, and every value-inspecting
# predicate on a tracer (comparisons, `isnan`, rounding, conversion) throws, so any computation
# whose control flow or value flow would need the actual numbers aborts the trace instead of
# producing an unsound certificate. `hasbranching` is additionally required to prove the traced
# path is the only path.

# Degrees saturate at `_NOTPOLY` ("not a polynomial of any bounded degree").
const _NOTPOLY = typemax(Int) ÷ 2

struct PolyDegree <: Real
d::Int
end

_satadd(a::Int, b::Int) = a >= _NOTPOLY || b >= _NOTPOLY ? _NOTPOLY : a + b
_satmul(a::Int, b::Int) = a >= _NOTPOLY || b >= _NOTPOLY ? _NOTPOLY : min(a * b, _NOTPOLY)

Base.promote_rule(::Type{PolyDegree}, ::Type{<:Real}) = PolyDegree
Base.convert(::Type{PolyDegree}, x::PolyDegree) = x
Base.convert(::Type{PolyDegree}, x::Real) = PolyDegree(0)
PolyDegree(x::PolyDegree) = x
# Disambiguators against Base's cross-family `Number` constructors: any value converted into the
# tracer domain is by definition not a seeded variable, hence a constant (degree 0).
PolyDegree(::Base.TwicePrecision) = PolyDegree(0)
PolyDegree(::Complex) = PolyDegree(0)
PolyDegree(::AbstractChar) = PolyDegree(0)
Base.zero(::Type{PolyDegree}) = PolyDegree(0)
Base.zero(::PolyDegree) = PolyDegree(0)
Base.one(::Type{PolyDegree}) = PolyDegree(0)
Base.one(::PolyDegree) = PolyDegree(0)
Base.oneunit(::Type{PolyDegree}) = PolyDegree(0)
Base.float(x::PolyDegree) = x
Base.widen(::Type{PolyDegree}) = PolyDegree

Base.:+(a::PolyDegree, b::PolyDegree) = PolyDegree(max(a.d, b.d))
Base.:-(a::PolyDegree, b::PolyDegree) = PolyDegree(max(a.d, b.d))
Base.:-(a::PolyDegree) = a
Base.:+(a::PolyDegree) = a
Base.:*(a::PolyDegree, b::PolyDegree) = PolyDegree(_satadd(a.d, b.d))
Base.:/(a::PolyDegree, b::PolyDegree) = b.d == 0 ? a : PolyDegree(_NOTPOLY)
Base.:\(a::PolyDegree, b::PolyDegree) = a.d == 0 ? b : PolyDegree(_NOTPOLY)
Base.inv(a::PolyDegree) = a.d == 0 ? a : PolyDegree(_NOTPOLY)
Base.muladd(a::PolyDegree, b::PolyDegree, c::PolyDegree) = a * b + c
Base.fma(a::PolyDegree, b::PolyDegree, c::PolyDegree) = a * b + c
Base.abs2(a::PolyDegree) = a * a
Base.conj(a::PolyDegree) = a
Base.real(a::PolyDegree) = a

function Base.:^(a::PolyDegree, n::Integer)
n == 0 && return PolyDegree(0)
n > 0 && return PolyDegree(_satmul(a.d, Int(n)))
return a.d == 0 ? a : PolyDegree(_NOTPOLY)
end
Base.:^(a::PolyDegree, b::PolyDegree) =
a.d == 0 && b.d == 0 ? PolyDegree(0) : PolyDegree(_NOTPOLY)

# Non-polynomial scalar functions: constants map to constants; anything else poisons.
for fn in (
:sqrt, :cbrt, :exp, :exp2, :exp10, :expm1, :log, :log2, :log10, :log1p,
:sin, :cos, :tan, :asin, :acos, :atan, :sinh, :cosh, :tanh, :asinh, :acosh,
:atanh, :sinpi, :cospi, :sec, :csc, :cot, :abs, :sign,
)
@eval Base.$fn(a::PolyDegree) = a.d == 0 ? a : PolyDegree(_NOTPOLY)
end
Base.atan(a::PolyDegree, b::PolyDegree) =
a.d == 0 && b.d == 0 ? PolyDegree(0) : PolyDegree(_NOTPOLY)
Base.hypot(a::PolyDegree, b::PolyDegree) =
a.d == 0 && b.d == 0 ? PolyDegree(0) : PolyDegree(_NOTPOLY)
Base.mod(a::PolyDegree, b::PolyDegree) = PolyDegree(_NOTPOLY)
Base.rem(a::PolyDegree, b::PolyDegree) = PolyDegree(_NOTPOLY)

struct DegreeTracerError <: Exception
op::Symbol
end
function Base.showerror(io::IO, e::DegreeTracerError)
return print(
io, "DegreeTracerError: `", e.op, "` needs the value of a traced number, which a ",
"degree tracer does not carry; the polynomial-degree certificate is abandoned."
)
end

# Every predicate or conversion that would need the traced VALUE aborts the trace: allowing any
# of these to answer would let value-dependent control or value flow leak into the certificate.
for fn in (:isless, :(==), :<, :(<=), :isequal)
@eval Base.$fn(a::PolyDegree, b::PolyDegree) = throw(DegreeTracerError($(QuoteNode(fn))))
end
for fn in (:isnan, :isinf, :isfinite, :iszero, :isone, :signbit, :isinteger)
@eval Base.$fn(a::PolyDegree) = throw(DegreeTracerError($(QuoteNode(fn))))
end
for fn in (:floor, :ceil, :trunc, :round)
@eval Base.$fn(a::PolyDegree) = throw(DegreeTracerError($(QuoteNode(fn))))
end

_seed(x::Real) = PolyDegree(1)
_seed(x::AbstractArray{<:Real}) = map(_ -> PolyDegree(1), x)
_seed(x) = throw(DegreeTracerError(:seed))

_max_degree(y::PolyDegree) = y.d
_max_degree(y::Real) = 0
_max_degree(y::AbstractArray) = isempty(y) ? 0 : maximum(_max_degree, y)
_max_degree(y::Tuple) = isempty(y) ? 0 : maximum(_max_degree, y)
_max_degree(@nospecialize(y)) = _NOTPOLY

_wrt_indices(wrt::Integer, n) = (Int(wrt),)
_wrt_indices(wrt::Colon, n) = ntuple(identity, n)
_wrt_indices(wrt, n) = Tuple(Int.(collect(wrt)))

# Certified upper bound on the real-arithmetic polynomial degree of `f` in the arguments selected
# by `wrt`, or `_NOTPOLY` when no certificate can be produced (non-polynomial operations reached
# non-constant values, the trace aborted, or `f` branches on values).
function _degree_bound(f, args, wrt)
idx = _wrt_indices(wrt, length(args))
all(i -> 1 <= i <= length(args), idx) || throw(ArgumentError("wrt index out of range"))
d = try
targs = ntuple(i -> i in idx ? _seed(args[i]) : args[i], length(args))
_max_degree(f(targs...))
catch
_NOTPOLY
end
d < _NOTPOLY || return _NOTPOLY
# The trace certifies the executed path; `hasbranching` certifies it is the only path.
return hasbranching(f, args...) ? _NOTPOLY : d
end

"""
islinear(f, x...; wrt = 1) -> Bool

Attempt to *prove* that `f` is an affine (polynomial degree ≤ 1) function of the arguments
selected by `wrt` (an index, collection of indices, or `:` for all; default the first argument),
holding the remaining arguments fixed at the values given. Arrays are tracked elementwise.

`true` is a certificate under real arithmetic: the degree bound is established by abstract
interpretation with degree-tracking tracer numbers, and [`hasbranching`](@ref) additionally
proves the traced path is the only path. `false` means *not proven* -- `f` may still be linear
(e.g. the bound does not model cancellation: `x^2 - x^2 + x` is not certified), so use `false`
as "fall back to the general path", never as a proof of nonlinearity.

```jldoctest
julia> using FunctionProperties

julia> islinear((u, p, t) -> p[1] * u[1] + p[2], [1.0], [2.0, 3.0], 0.0)
true

julia> islinear((u, p, t) -> u[1] * u[2], [1.0, 2.0], nothing, 0.0)
false
```
"""
function islinear(f, x...; wrt = 1)
return _degree_bound(f, x, wrt) <= 1
end

"""
isquadratic(f, x...; wrt = 1) -> Bool

Attempt to *prove* that `f` is a polynomial of degree ≤ 2 in the arguments selected by `wrt`,
holding the remaining arguments fixed. Same certification semantics and conservatism as
[`islinear`](@ref): `true` is a proof under real arithmetic, `false` means not proven.

```jldoctest
julia> using FunctionProperties

julia> isquadratic((u, p, t) -> u[1] * u[2] + p[1] * u[1], [1.0, 2.0], [3.0], 0.0)
true

julia> isquadratic((u, p, t) -> exp(u[1]), [1.0], nothing, 0.0)
false
```
"""
function isquadratic(f, x...; wrt = 1)
return _degree_bound(f, x, wrt) <= 2
end
59 changes: 59 additions & 0 deletions test/core_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,62 @@ wrap_cmp(x, t) = x > t ? x : t
@test FunctionProperties.hasbranching(Base.Fix1(wrap_cmp, 0.0), 1.0)
@test FunctionProperties.hasbranching(Base.Fix2(wrap_cmp, 0.0), 1.0)
@test !FunctionProperties.hasbranching(Base.Fix2(*, 2.0), 1.0)

# ---------------------------------------------------------------------------------------------
# `islinear` / `isquadratic`: degree certification by tracer-type abstract interpretation.
# `true` is a proof under real arithmetic; `false` is only "not proven".
A_lin = [1.0 2.0; 3.0 4.0]
@test islinear((u, p, t) -> p[1] * u[1] + p[2], [1.0], [2.0, 3.0], 0.0)
@test islinear(u -> A_lin * u, [1.0, 2.0]) # generic matmul certifies
@test islinear(u -> A_lin * u .+ 1.0, [1.0, 2.0])
@test islinear(x -> 2x + 3, 1.0)
@test islinear(x -> 0.0, 1.0) # constants are affine
@test !islinear((u, p, t) -> u[1] * u[2], [1.0, 2.0], nothing, 0.0)
@test isquadratic((u, p, t) -> u[1] * u[2] + p[1] * u[1], [1.0, 2.0], [3.0], 0.0)
@test isquadratic(x -> (x + 1.0)^2, 1.0)
@test !islinear(x -> (x + 1.0)^2, 1.0)
@test !isquadratic(x -> x^3, 1.0)
@test !isquadratic(u -> exp(u[1]), [1.0])
@test !islinear(u -> max.(u, 0.0), [1.0]) # tracer aborts on comparison
@test !islinear(x -> x > 0 ? x : zero(x), 1.0) # relu-style branch
@test !islinear(x -> x * x - x * x + x, 1.0) # cancellation: conservative, documented
# `wrt` semantics: joint degree in the tracked arguments, others held fixed.
@test islinear((u, v) -> u[1] + v[1], [1.0], [1.0]; wrt = (1, 2))
@test !islinear((u, v) -> u[1] * v[1], [1.0], [1.0]; wrt = (1, 2))
@test islinear((u, p) -> u[1] * p[1], [1.0], [2.0]) # linear in u for fixed p
@test islinear((u, p) -> u[1] * p[1], [1.0], [2.0]; wrt = :) == false
# A branch on an UNTRACKED argument still blocks certification (`hasbranching` guard): the
# function is linear in `u` for the given `p`, but the certificate is conservatively withheld.
@test !islinear((u, p, t) -> p[1] > 0 ? u[1] : 2u[1], [1.0], [1.0], 0.0)
# In-place right-hand sides via the closure pattern.
rhs_ip!(du, u, p, t) = (du[1] = p[1] * u[1]; du[2] = u[1] + u[2]; du)
@test islinear(u -> rhs_ip!(similar(u), u, [2.0], 0.0), [1.0, 2.0])
# Certified results are plain Bools; the tracer type does not leak.
@test islinear(x -> 2x, 1.0) isa Bool

# Ground-truth cross-validation: every `true` linear certificate must have exactly vanishing
# second finite differences over exact rational arithmetic (and third differences for quadratic
# certificates) at random rational points -- a soundness check independent of the tracer rules.
let rng = Random.Xoshiro(0x1517)
d2(f, x, h1, h2) = f(x + h1 + h2) - f(x + h1) - f(x + h2) + f(x)
corpusf = [
(x -> 3 // 2 * x + 7, true),
(x -> x * (x + 1) - x * x, true), # cancellation: genuinely linear...
(x -> (x + 2) * 5 - 3, true),
(x -> x^2 + x, false),
(x -> x^3 - x, false),
(x -> x * x * 2 + 1, false),
]
for (f, lin_truth) in corpusf
cert = islinear(f, 1.0)
# soundness: a certificate implies exact linearity at random rational probes
if cert
for _ in 1:3
x, h1, h2 = (Rational{BigInt}(rand(rng, -99:99)) // rand(rng, 1:9) for _ in 1:3)
@test iszero(d2(f, x, h1, h2))
end
end
# no false certificates on the known-nonlinear corpus
lin_truth || @test !cert
end
end
5 changes: 4 additions & 1 deletion test/qa/qa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ using SciMLTesting, FunctionProperties, JET, Test
# method with constant argument types". This dependency is deliberately confined behind a
# functional capability probe (`_const_prop_capable`) so the package degrades to the plain
# type scan wherever these internals change shape.
# - `TwicePrecision` appears only in a constructor disambiguation that Aqua requires: Base
# defines `(::Type{T<:Number})(::Base.TwicePrecision)`, which is ambiguous against the
# degree tracer's generic constructor.
# All of these are Core/Base compiler-introspection internals with no public API, so they are
# ignored in the public-API checks.
run_qa(
Expand All @@ -32,7 +35,7 @@ run_qa(
ignore = (
:Typeof, :Argument, :CodeInfo, :Compiler, :Const, :GotoNode,
:InferenceResult, :InferenceState, :MethodInstance, :NativeInterpreter,
:NewvarNode, :PartialStruct, :ReturnNode, :SSAValue, :SlotNumber,
:NewvarNode, :PartialStruct, :ReturnNode, :SSAValue, :SlotNumber, :TwicePrecision,
:code_typed_by_type, :get_world_counter, :retrieve_code_info,
:specialize_method, :svec, :typeinf,
),
Expand Down
Loading