[WIP] Add register_multivariate#241
Closed
andrewrosemberg wants to merge 2 commits into exanauts:main from
Closed
Conversation
Contributor
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.
diff --git a/src/gradient.jl b/src/gradient.jl
index 55bbb36..8812f9e 100644
--- a/src/gradient.jl
+++ b/src/gradient.jl
@@ -24,7 +24,7 @@ end
@inbounds y[d.i] += adj
nothing
end
-@generated function drpass(d::ExaModels.AdjointNodeN{F,N}, y, adj) where {F,N}
+@generated function drpass(d::ExaModels.AdjointNodeN{F, N}, y, adj) where {F, N}
stmts = [:(ExaModels.drpass(d.args[$k], y, adj * d.g[$k])) for k in 1:N]
return quote
$(stmts...)
@@ -92,13 +92,13 @@ end
return cnt
end
@generated function grpass(
- d::ExaModels.AdjointNodeN{F,N},
- comp,
- y,
- o1,
- cnt,
- adj,
-) where {F,N}
+ d::ExaModels.AdjointNodeN{F, N},
+ comp,
+ y,
+ o1,
+ cnt,
+ adj,
+ ) where {F, N}
stmts = [:(cnt = ExaModels.grpass(d.args[$k], comp, y, o1, cnt, adj * d.g[$k])) for k in 1:N]
return quote
$(stmts...)
diff --git a/src/graph.jl b/src/graph.jl
index 3122063..480280e 100644
--- a/src/graph.jl
+++ b/src/graph.jl
@@ -322,12 +322,12 @@ A node with N children for symbolic expression tree (multivariate registered fun
# Fields:
- `args::Args`: tuple of N children nodes
"""
-struct NodeN{F,N,Args} <: AbstractNode
+struct NodeN{F, N, Args} <: AbstractNode
args::Args
end
-@inline NodeN(f::F, args::Args) where {F,N,Args<:NTuple{N,Any}} =
- NodeN{F,N,Args}(args)
+@inline NodeN(f::F, args::Args) where {F, N, Args <: NTuple{N, Any}} =
+ NodeN{F, N, Args}(args)
"""
AdjointNodeN{F, N, T, Args}
@@ -339,14 +339,14 @@ A node with N children for first-order forward pass tree (multivariate registere
- `g::NTuple{N,T}`: first-order sensitivities (∂f/∂xᵢ) for each argument
- `args::Args`: tuple of N children `AbstractAdjointNode`s
"""
-struct AdjointNodeN{F,N,T,Args} <: AbstractAdjointNode
+struct AdjointNodeN{F, N, T, Args} <: AbstractAdjointNode
x::T
- g::NTuple{N,T}
+ g::NTuple{N, T}
args::Args
end
-@inline AdjointNodeN(f::F, x::T, g::NTuple{N,T}, args::Args) where {F,N,T,Args} =
- AdjointNodeN{F,N,T,Args}(x, g, args)
+@inline AdjointNodeN(f::F, x::T, g::NTuple{N, T}, args::Args) where {F, N, T, Args} =
+ AdjointNodeN{F, N, T, Args}(x, g, args)
"""
SecondAdjointNodeN{F, N, T, Args}
@@ -360,20 +360,20 @@ A node with N children for second-order forward pass tree (multivariate register
K = N*(N+1)÷2. Entry (i,j) with i≤j is at position i*(2N-i+1)÷2 + (j-i) + 1.
- `args::Args`: tuple of N children `AbstractSecondAdjointNode`s
"""
-struct SecondAdjointNodeN{F,N,K,T,Args} <: AbstractSecondAdjointNode
+struct SecondAdjointNodeN{F, N, K, T, Args} <: AbstractSecondAdjointNode
x::T
- g::NTuple{N,T}
- h::NTuple{K,T}
+ g::NTuple{N, T}
+ h::NTuple{K, T}
args::Args
end
@inline SecondAdjointNodeN(
f::F,
x::T,
- g::NTuple{N,T},
- h::NTuple{K,T},
+ g::NTuple{N, T},
+ h::NTuple{K, T},
args::Args,
-) where {F,N,K,T,Args} = SecondAdjointNodeN{F,N,K,T,Args}(x, g, h, args)
+) where {F, N, K, T, Args} = SecondAdjointNodeN{F, N, K, T, Args}(x, g, h, args)
# Upper-triangular index helper: (i,j) with 1-based i≤j
@inline _hess_index(i, j, N) = (i - 1) * N - (i - 1) * (i - 2) ÷ 2 + (j - i) + 1
diff --git a/src/hessian.jl b/src/hessian.jl
index 16c4be2..7ab89ee 100644
--- a/src/hessian.jl
+++ b/src/hessian.jl
@@ -250,15 +250,15 @@ end
# hdrpass for SecondAdjointNodeN paired with any other SecondAdjointNode type
@generated function hdrpass(
- t1::ExaModels.SecondAdjointNodeN{F,N},
- t2::T2,
- comp,
- y1,
- y2,
- o2,
- cnt,
- adj,
-) where {F,N,T2<:AbstractSecondAdjointNode}
+ t1::ExaModels.SecondAdjointNodeN{F, N},
+ t2::T2,
+ comp,
+ y1,
+ y2,
+ o2,
+ cnt,
+ adj,
+ ) where {F, N, T2 <: AbstractSecondAdjointNode}
stmts = [:(cnt = ExaModels.hdrpass(t1.args[$k], t2, comp, y1, y2, o2, cnt, adj * t1.g[$k])) for k in 1:N]
return quote
$(stmts...)
@@ -267,15 +267,15 @@ end
end
@generated function hdrpass(
- t1::T1,
- t2::ExaModels.SecondAdjointNodeN{F,N},
- comp,
- y1,
- y2,
- o2,
- cnt,
- adj,
-) where {T1<:AbstractSecondAdjointNode,F,N}
+ t1::T1,
+ t2::ExaModels.SecondAdjointNodeN{F, N},
+ comp,
+ y1,
+ y2,
+ o2,
+ cnt,
+ adj,
+ ) where {T1 <: AbstractSecondAdjointNode, F, N}
stmts = [:(cnt = ExaModels.hdrpass(t1, t2.args[$k], comp, y1, y2, o2, cnt, adj * t2.g[$k])) for k in 1:N]
return quote
$(stmts...)
@@ -285,15 +285,15 @@ end
# Disambiguate: SecondAdjointNodeN × SecondAdjointNodeN
@generated function hdrpass(
- t1::ExaModels.SecondAdjointNodeN{F1,N1},
- t2::ExaModels.SecondAdjointNodeN{F2,N2},
- comp,
- y1,
- y2,
- o2,
- cnt,
- adj,
-) where {F1,N1,F2,N2}
+ t1::ExaModels.SecondAdjointNodeN{F1, N1},
+ t2::ExaModels.SecondAdjointNodeN{F2, N2},
+ comp,
+ y1,
+ y2,
+ o2,
+ cnt,
+ adj,
+ ) where {F1, N1, F2, N2}
stmts = Expr[]
for i in 1:N1
for j in 1:N2
@@ -419,15 +419,15 @@ end
end
@generated function hrpass(
- t::ExaModels.SecondAdjointNodeN{F,N},
- comp,
- y1,
- y2,
- o2,
- cnt,
- adj,
- adj2,
-) where {F,N}
+ t::ExaModels.SecondAdjointNodeN{F, N},
+ comp,
+ y1,
+ y2,
+ o2,
+ cnt,
+ adj,
+ adj2,
+ ) where {F, N}
stmts = Expr[]
# diagonal terms: ∂f/∂xᵢ * ∂²cᵢ/∂x² + adj2 * (∂f/∂xᵢ)² * ∂cᵢ/∂x ⊗ ∂cᵢ/∂x
for k in 1:N
@@ -435,19 +435,27 @@ end
# simpler formula: row k, col k → offset = (k-1)*N - (k-1)*(k-2)/2 + 1
# = (k-1)*(2N - k + 2)/2 + 1 ... let's compute manually
idx = (k - 1) * N - (k - 1) * (k - 2) ÷ 2 + 1 # (k,k) entry, 1-based
- push!(stmts, :(cnt = ExaModels.hrpass(
- t.args[$k], comp, y1, y2, o2, cnt,
- adj * t.g[$k],
- adj2 * t.g[$k]^2 + adj * t.h[$idx],
- )))
+ push!(
+ stmts, :(
+ cnt = ExaModels.hrpass(
+ t.args[$k], comp, y1, y2, o2, cnt,
+ adj * t.g[$k],
+ adj2 * t.g[$k]^2 + adj * t.h[$idx],
+ )
+ )
+ )
end
# off-diagonal cross terms: (∂²f/∂xᵢ∂xⱼ) * ∂cᵢ/∂x ⊗ ∂cⱼ/∂x for i < j
- for i in 1:N, j in (i+1):N
+ for i in 1:N, j in (i + 1):N
idx_ij = (i - 1) * N - (i - 1) * (i - 2) ÷ 2 + (j - i) + 1 # (i,j) entry
- push!(stmts, :(cnt = ExaModels.hdrpass(
- t.args[$i], t.args[$j], comp, y1, y2, o2, cnt,
- adj2 * t.g[$i] * t.g[$j] + adj * t.h[$idx_ij],
- )))
+ push!(
+ stmts, :(
+ cnt = ExaModels.hdrpass(
+ t.args[$i], t.args[$j], comp, y1, y2, o2, cnt,
+ adj2 * t.g[$i] * t.g[$j] + adj * t.h[$idx_ij],
+ )
+ )
+ )
end
return quote
$(stmts...)
@@ -456,30 +464,38 @@ end
end
@generated function hrpass0(
- t::ExaModels.SecondAdjointNodeN{F,N},
- comp,
- y1,
- y2,
- o2,
- cnt,
- adj,
- adj2,
-) where {F,N}
+ t::ExaModels.SecondAdjointNodeN{F, N},
+ comp,
+ y1,
+ y2,
+ o2,
+ cnt,
+ adj,
+ adj2,
+ ) where {F, N}
stmts = Expr[]
for k in 1:N
idx = (k - 1) * N - (k - 1) * (k - 2) ÷ 2 + 1
- push!(stmts, :(cnt = ExaModels.hrpass(
- t.args[$k], comp, y1, y2, o2, cnt,
- adj * t.g[$k],
- adj2 * t.g[$k]^2 + adj * t.h[$idx],
- )))
+ push!(
+ stmts, :(
+ cnt = ExaModels.hrpass(
+ t.args[$k], comp, y1, y2, o2, cnt,
+ adj * t.g[$k],
+ adj2 * t.g[$k]^2 + adj * t.h[$idx],
+ )
+ )
+ )
end
- for i in 1:N, j in (i+1):N
+ for i in 1:N, j in (i + 1):N
idx_ij = (i - 1) * N - (i - 1) * (i - 2) ÷ 2 + (j - i) + 1
- push!(stmts, :(cnt = ExaModels.hdrpass(
- t.args[$i], t.args[$j], comp, y1, y2, o2, cnt,
- adj2 * t.g[$i] * t.g[$j] + adj * t.h[$idx_ij],
- )))
+ push!(
+ stmts, :(
+ cnt = ExaModels.hdrpass(
+ t.args[$i], t.args[$j], comp, y1, y2, o2, cnt,
+ adj2 * t.g[$i] * t.g[$j] + adj * t.h[$idx_ij],
+ )
+ )
+ )
end
return quote
$(stmts...)
diff --git a/src/jacobian.jl b/src/jacobian.jl
index 581dd38..4d2fae8 100644
--- a/src/jacobian.jl
+++ b/src/jacobian.jl
@@ -39,15 +39,15 @@ end
return cnt
end
@generated function jrpass(
- d::ExaModels.AdjointNodeN{F,N},
- comp,
- i,
- y1,
- y2,
- o1,
- cnt,
- adj,
-) where {F,N}
+ d::ExaModels.AdjointNodeN{F, N},
+ comp,
+ i,
+ y1,
+ y2,
+ o1,
+ cnt,
+ adj,
+ ) where {F, N}
stmts = [:(cnt = ExaModels.jrpass(d.args[$k], comp, i, y1, y2, o1, cnt, adj * d.g[$k])) for k in 1:N]
return quote
$(stmts...)
diff --git a/src/register.jl b/src/register.jl
index ac347eb..dfafa8f 100644
--- a/src/register.jl
+++ b/src/register.jl
@@ -34,12 +34,12 @@ macro register_multivariate(f, N_expr, grad, hess)
N = N_expr isa Integer ? N_expr : N_expr # handled at parse time
# Build (d1, d2, ..., dN) argument symbols
arg_syms = [Symbol("_d", k) for k in 1:N]
- arg_sym_types_node = [:($(Symbol("D",k))<:ExaModels.AbstractNode) for k in 1:N]
- arg_sym_types_adjoint = [:($(Symbol("D",k))<:ExaModels.AbstractAdjointNode) for k in 1:N]
- arg_sym_types_sadjoint = [:($(Symbol("D",k))<:ExaModels.AbstractSecondAdjointNode) for k in 1:N]
- arg_decls_node = [:($(arg_syms[k])::$(Symbol("D",k))) for k in 1:N]
- arg_decls_adjoint = [:($(arg_syms[k])::$(Symbol("D",k))) for k in 1:N]
- arg_decls_sadjoint = [:($(arg_syms[k])::$(Symbol("D",k))) for k in 1:N]
+ arg_sym_types_node = [:($(Symbol("D", k)) <: ExaModels.AbstractNode) for k in 1:N]
+ arg_sym_types_adjoint = [:($(Symbol("D", k)) <: ExaModels.AbstractAdjointNode) for k in 1:N]
+ arg_sym_types_sadjoint = [:($(Symbol("D", k)) <: ExaModels.AbstractSecondAdjointNode) for k in 1:N]
+ arg_decls_node = [:($(arg_syms[k])::$(Symbol("D", k))) for k in 1:N]
+ arg_decls_adjoint = [:($(arg_syms[k])::$(Symbol("D", k))) for k in 1:N]
+ arg_decls_sadjoint = [:($(arg_syms[k])::$(Symbol("D", k))) for k in 1:N]
# Primal values of each child (for calling f, grad, hess at the evaluation point)
xvals = [:($(arg_syms[k]).x) for k in 1:N]
@@ -49,14 +49,14 @@ macro register_multivariate(f, N_expr, grad, hess)
# 1. AbstractNode overload — builds symbolic graph node
if !hasmethod($f, Tuple{$(fill(:(ExaModels.AbstractNode), N)...)})
@inline function $f($(arg_decls_node...)) where {$(arg_sym_types_node...)}
- ExaModels.NodeN($f, ($(arg_syms...),))
+ return ExaModels.NodeN($f, ($(arg_syms...),))
end
end
# 2. AbstractAdjointNode overload — first-order AD.
# grad(x1,...,xN) returns NTuple{N} — no heap allocation, GPU-safe.
@inline function $f($(arg_decls_adjoint...)) where {$(arg_sym_types_adjoint...)}
- ExaModels.AdjointNodeN(
+ return ExaModels.AdjointNodeN(
$f,
$f($(xvals...)),
$grad($(xvals...)),
@@ -68,7 +68,7 @@ macro register_multivariate(f, N_expr, grad, hess)
# grad(x1,...,xN) and hess(x1,...,xN) return NTuples —
# no heap allocation, GPU-safe.
@inline function $f($(arg_decls_sadjoint...)) where {$(arg_sym_types_sadjoint...)}
- ExaModels.SecondAdjointNodeN(
+ return ExaModels.SecondAdjointNodeN(
$f,
$f($(xvals...)),
$grad($(xvals...)),
@@ -78,7 +78,7 @@ macro register_multivariate(f, N_expr, grad, hess)
end
# 4. Evaluation overload for NodeN
- @inline (n::ExaModels.NodeN{typeof($f),$N})(i, x, θ) =
+ @inline (n::ExaModels.NodeN{typeof($f), $N})(i, x, θ) =
$f($([:(n.args[$k](i, x, θ)) for k in 1:N]...))
end,
)
diff --git a/test/NLPTest/multivariate_test.jl b/test/NLPTest/multivariate_test.jl
index 6fd05b8..2173741 100644
--- a/test/NLPTest/multivariate_test.jl
+++ b/test/NLPTest/multivariate_test.jl
@@ -27,7 +27,7 @@ _hess_g2(x, y) = (-sin(x) * cos(y), -cos(x) * sin(y), -sin(x) * cos(y))
"""
Finite-difference gradient of ExaModel objective or single constraint.
"""
-function fd_gradient(m, x0; h = 1e-5)
+function fd_gradient(m, x0; h = 1.0e-5)
n = length(x0)
g = zeros(n)
f0 = NLPModels.obj(m, x0)
@@ -38,7 +38,7 @@ function fd_gradient(m, x0; h = 1e-5)
return g
end
-function fd_jacobian(m, x0; h = 1e-5)
+function fd_jacobian(m, x0; h = 1.0e-5)
n = length(x0)
ncon = m.meta.ncon
J = zeros(ncon, n)
@@ -50,7 +50,7 @@ function fd_jacobian(m, x0; h = 1e-5)
return J
end
-function fd_hessian_lag(m, x0, y0; h = 1e-5)
+function fd_hessian_lag(m, x0, y0; h = 1.0e-5)
n = length(x0)
H = zeros(n, n)
g0 = NLPModels.grad(m, x0) .+ NLPModels.jtprod(m, x0, y0)
@@ -71,21 +71,21 @@ Test @register_multivariate with a 3-argument quadratic function as objective.
Compares ExaModels gradient/Hessian to finite differences.
"""
function test_multivariate_objective(backend)
- @testset "register_multivariate: 3-arg objective" begin
+ return @testset "register_multivariate: 3-arg objective" begin
N = 6
c = ExaCore(; backend = backend)
x = variable(c, N; start = [Float64(i) for i in 1:N])
# objective: sum_i _f3(x[i], x[i+1], x[i+2])
- objective(c, _f3(x[i], x[i + 1], x[i + 2]) for i in 1:N-2)
+ objective(c, _f3(x[i], x[i + 1], x[i + 2]) for i in 1:(N - 2))
m = ExaModel(c)
w = WrapperNLPModel(m)
x0 = copy(w.meta.x0)
g_exa = NLPModels.grad(w, x0)
- g_fd = fd_gradient(w, x0)
- @test g_exa ≈ g_fd atol = 1e-4
+ g_fd = fd_gradient(w, x0)
+ @test g_exa ≈ g_fd atol = 1.0e-4
# Hessian (no constraints, so y = [])
hI = zeros(Int, w.meta.nnzh)
@@ -103,7 +103,7 @@ function test_multivariate_objective(backend)
end
H_fd = fd_hessian_lag(w, x0, Float64[])
- @test H_exa ≈ H_fd atol = 1e-3
+ @test H_exa ≈ H_fd atol = 1.0e-3
end
end
@@ -112,7 +112,7 @@ Test @register_multivariate with a 2-argument function as constraint.
Compares ExaModels Jacobian to finite differences.
"""
function test_multivariate_constraint(backend)
- @testset "register_multivariate: 2-arg constraint" begin
+ return @testset "register_multivariate: 2-arg constraint" begin
N = 4
c = ExaCore(; backend = backend)
x = variable(c, N; start = [0.5 + 0.1 * Float64(i) for i in 1:N])
@@ -120,7 +120,7 @@ function test_multivariate_constraint(backend)
# constraint: _g2(x[i], x[i+1]) ∈ [-1, 1] for i = 1:N-1
constraint(
c,
- _g2(x[i], x[i + 1]) for i in 1:N-1;
+ _g2(x[i], x[i + 1]) for i in 1:(N - 1);
lcon = -ones(N - 1),
ucon = ones(N - 1),
)
@@ -145,7 +145,7 @@ function test_multivariate_constraint(backend)
end
J_fd = fd_jacobian(w, x0)
- @test J_exa ≈ J_fd atol = 1e-4
+ @test J_exa ≈ J_fd atol = 1.0e-4
# Hessian of Lagrangian
y0 = randn(N - 1)
@@ -164,7 +164,7 @@ function test_multivariate_constraint(backend)
end
H_fd = fd_hessian_lag(w, x0, y0)
- @test H_exa ≈ H_fd atol = 1e-3
+ @test H_exa ≈ H_fd atol = 1.0e-3
end
end
@@ -173,7 +173,7 @@ Test that @register_multivariate interoperates correctly with ExaModels'
native symbolic operations (Node1, Node2) in the same expression.
"""
function test_multivariate_mixed_expression(backend)
- @testset "register_multivariate: mixed with native ops" begin
+ return @testset "register_multivariate: mixed with native ops" begin
N = 6
c = ExaCore(; backend = backend)
x = variable(c, N; start = [0.3 * Float64(i) for i in 1:N])
@@ -182,7 +182,7 @@ function test_multivariate_mixed_expression(backend)
# This tests that NodeN can be a child of a Node1/Node2 and vice versa.
objective(
c,
- _f3(x[i], sin(x[i + 1]), x[i + 2]^2) for i in 1:N-2
+ _f3(x[i], sin(x[i + 1]), x[i + 2]^2) for i in 1:(N - 2)
)
m = ExaModel(c)
@@ -190,8 +190,8 @@ function test_multivariate_mixed_expression(backend)
x0 = copy(w.meta.x0)
g_exa = NLPModels.grad(w, x0)
- g_fd = fd_gradient(w, x0)
- @test g_exa ≈ g_fd atol = 1e-4
+ g_fd = fd_gradient(w, x0)
+ @test g_exa ≈ g_fd atol = 1.0e-4
end
end
@@ -201,5 +201,5 @@ Run all @register_multivariate tests.
function test_multivariate(backend)
test_multivariate_objective(backend)
test_multivariate_constraint(backend)
- test_multivariate_mixed_expression(backend)
+ return test_multivariate_mixed_expression(backend)
end |
Author
|
closing in favor of #243 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit. This suggestion is invalid because no changes were made to the code. Suggestions cannot be applied while the pull request is closed. Suggestions cannot be applied while viewing a subset of changes. Only one suggestion per line can be applied in a batch. Add this suggestion to a batch that can be applied as a single commit. Applying suggestions on deleted lines is not supported. You must change the existing code in this line in order to create a valid suggestion. Outdated suggestions cannot be applied. This suggestion has been applied or marked resolved. Suggestions cannot be applied from pending reviews. Suggestions cannot be applied on multi-line comments. Suggestions cannot be applied while the pull request is queued to merge. Suggestion cannot be applied right now. Please check back later.
closes #239