Skip to content

[WIP] Add register_multivariate #241

Closed
andrewrosemberg wants to merge 2 commits into exanauts:main from
andrewrosemberg:ar/mimo_operator
Closed

[WIP] Add register_multivariate #241
andrewrosemberg wants to merge 2 commits into exanauts:main from
andrewrosemberg:ar/mimo_operator

Conversation

@andrewrosemberg
Copy link

closes #239

@github-actions
Copy link
Contributor

github-actions bot commented Mar 3, 2026

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/src/gradient.jl b/src/gradient.jl
index 55bbb36..8812f9e 100644
--- a/src/gradient.jl
+++ b/src/gradient.jl
@@ -24,7 +24,7 @@ end
     @inbounds y[d.i] += adj
     nothing
 end
-@generated function drpass(d::ExaModels.AdjointNodeN{F,N}, y, adj) where {F,N}
+@generated function drpass(d::ExaModels.AdjointNodeN{F, N}, y, adj) where {F, N}
     stmts = [:(ExaModels.drpass(d.args[$k], y, adj * d.g[$k])) for k in 1:N]
     return quote
         $(stmts...)
@@ -92,13 +92,13 @@ end
     return cnt
 end
 @generated function grpass(
-    d::ExaModels.AdjointNodeN{F,N},
-    comp,
-    y,
-    o1,
-    cnt,
-    adj,
-) where {F,N}
+        d::ExaModels.AdjointNodeN{F, N},
+        comp,
+        y,
+        o1,
+        cnt,
+        adj,
+    ) where {F, N}
     stmts = [:(cnt = ExaModels.grpass(d.args[$k], comp, y, o1, cnt, adj * d.g[$k])) for k in 1:N]
     return quote
         $(stmts...)
diff --git a/src/graph.jl b/src/graph.jl
index 3122063..480280e 100644
--- a/src/graph.jl
+++ b/src/graph.jl
@@ -322,12 +322,12 @@ A node with N children for symbolic expression tree (multivariate registered fun
 # Fields:
 - `args::Args`: tuple of N children nodes
 """
-struct NodeN{F,N,Args} <: AbstractNode
+struct NodeN{F, N, Args} <: AbstractNode
     args::Args
 end
 
-@inline NodeN(f::F, args::Args) where {F,N,Args<:NTuple{N,Any}} =
-    NodeN{F,N,Args}(args)
+@inline NodeN(f::F, args::Args) where {F, N, Args <: NTuple{N, Any}} =
+    NodeN{F, N, Args}(args)
 
 """
     AdjointNodeN{F, N, T, Args}
@@ -339,14 +339,14 @@ A node with N children for first-order forward pass tree (multivariate registere
 - `g::NTuple{N,T}`: first-order sensitivities (∂f/∂xᵢ) for each argument
 - `args::Args`: tuple of N children `AbstractAdjointNode`s
 """
-struct AdjointNodeN{F,N,T,Args} <: AbstractAdjointNode
+struct AdjointNodeN{F, N, T, Args} <: AbstractAdjointNode
     x::T
-    g::NTuple{N,T}
+    g::NTuple{N, T}
     args::Args
 end
 
-@inline AdjointNodeN(f::F, x::T, g::NTuple{N,T}, args::Args) where {F,N,T,Args} =
-    AdjointNodeN{F,N,T,Args}(x, g, args)
+@inline AdjointNodeN(f::F, x::T, g::NTuple{N, T}, args::Args) where {F, N, T, Args} =
+    AdjointNodeN{F, N, T, Args}(x, g, args)
 
 """
     SecondAdjointNodeN{F, N, T, Args}
@@ -360,20 +360,20 @@ A node with N children for second-order forward pass tree (multivariate register
   K = N*(N+1)÷2. Entry (i,j) with i≤j is at position i*(2N-i+1)÷2 + (j-i) + 1.
 - `args::Args`: tuple of N children `AbstractSecondAdjointNode`s
 """
-struct SecondAdjointNodeN{F,N,K,T,Args} <: AbstractSecondAdjointNode
+struct SecondAdjointNodeN{F, N, K, T, Args} <: AbstractSecondAdjointNode
     x::T
-    g::NTuple{N,T}
-    h::NTuple{K,T}
+    g::NTuple{N, T}
+    h::NTuple{K, T}
     args::Args
 end
 
 @inline SecondAdjointNodeN(
     f::F,
     x::T,
-    g::NTuple{N,T},
-    h::NTuple{K,T},
+    g::NTuple{N, T},
+    h::NTuple{K, T},
     args::Args,
-) where {F,N,K,T,Args} = SecondAdjointNodeN{F,N,K,T,Args}(x, g, h, args)
+) where {F, N, K, T, Args} = SecondAdjointNodeN{F, N, K, T, Args}(x, g, h, args)
 
 # Upper-triangular index helper: (i,j) with 1-based i≤j
 @inline _hess_index(i, j, N) = (i - 1) * N - (i - 1) * (i - 2) ÷ 2 + (j - i) + 1
diff --git a/src/hessian.jl b/src/hessian.jl
index 16c4be2..7ab89ee 100644
--- a/src/hessian.jl
+++ b/src/hessian.jl
@@ -250,15 +250,15 @@ end
 
 # hdrpass for SecondAdjointNodeN paired with any other SecondAdjointNode type
 @generated function hdrpass(
-    t1::ExaModels.SecondAdjointNodeN{F,N},
-    t2::T2,
-    comp,
-    y1,
-    y2,
-    o2,
-    cnt,
-    adj,
-) where {F,N,T2<:AbstractSecondAdjointNode}
+        t1::ExaModels.SecondAdjointNodeN{F, N},
+        t2::T2,
+        comp,
+        y1,
+        y2,
+        o2,
+        cnt,
+        adj,
+    ) where {F, N, T2 <: AbstractSecondAdjointNode}
     stmts = [:(cnt = ExaModels.hdrpass(t1.args[$k], t2, comp, y1, y2, o2, cnt, adj * t1.g[$k])) for k in 1:N]
     return quote
         $(stmts...)
@@ -267,15 +267,15 @@ end
 end
 
 @generated function hdrpass(
-    t1::T1,
-    t2::ExaModels.SecondAdjointNodeN{F,N},
-    comp,
-    y1,
-    y2,
-    o2,
-    cnt,
-    adj,
-) where {T1<:AbstractSecondAdjointNode,F,N}
+        t1::T1,
+        t2::ExaModels.SecondAdjointNodeN{F, N},
+        comp,
+        y1,
+        y2,
+        o2,
+        cnt,
+        adj,
+    ) where {T1 <: AbstractSecondAdjointNode, F, N}
     stmts = [:(cnt = ExaModels.hdrpass(t1, t2.args[$k], comp, y1, y2, o2, cnt, adj * t2.g[$k])) for k in 1:N]
     return quote
         $(stmts...)
@@ -285,15 +285,15 @@ end
 
 # Disambiguate: SecondAdjointNodeN × SecondAdjointNodeN
 @generated function hdrpass(
-    t1::ExaModels.SecondAdjointNodeN{F1,N1},
-    t2::ExaModels.SecondAdjointNodeN{F2,N2},
-    comp,
-    y1,
-    y2,
-    o2,
-    cnt,
-    adj,
-) where {F1,N1,F2,N2}
+        t1::ExaModels.SecondAdjointNodeN{F1, N1},
+        t2::ExaModels.SecondAdjointNodeN{F2, N2},
+        comp,
+        y1,
+        y2,
+        o2,
+        cnt,
+        adj,
+    ) where {F1, N1, F2, N2}
     stmts = Expr[]
     for i in 1:N1
         for j in 1:N2
@@ -419,15 +419,15 @@ end
 end
 
 @generated function hrpass(
-    t::ExaModels.SecondAdjointNodeN{F,N},
-    comp,
-    y1,
-    y2,
-    o2,
-    cnt,
-    adj,
-    adj2,
-) where {F,N}
+        t::ExaModels.SecondAdjointNodeN{F, N},
+        comp,
+        y1,
+        y2,
+        o2,
+        cnt,
+        adj,
+        adj2,
+    ) where {F, N}
     stmts = Expr[]
     # diagonal terms: ∂f/∂xᵢ * ∂²cᵢ/∂x² + adj2 * (∂f/∂xᵢ)² * ∂cᵢ/∂x ⊗ ∂cᵢ/∂x
     for k in 1:N
@@ -435,19 +435,27 @@ end
         # simpler formula: row k, col k → offset = (k-1)*N - (k-1)*(k-2)/2 + 1
         # = (k-1)*(2N - k + 2)/2 + 1  ... let's compute manually
         idx = (k - 1) * N - (k - 1) * (k - 2) ÷ 2 + 1  # (k,k) entry, 1-based
-        push!(stmts, :(cnt = ExaModels.hrpass(
-            t.args[$k], comp, y1, y2, o2, cnt,
-            adj * t.g[$k],
-            adj2 * t.g[$k]^2 + adj * t.h[$idx],
-        )))
+        push!(
+            stmts, :(
+                cnt = ExaModels.hrpass(
+                    t.args[$k], comp, y1, y2, o2, cnt,
+                    adj * t.g[$k],
+                    adj2 * t.g[$k]^2 + adj * t.h[$idx],
+                )
+            )
+        )
     end
     # off-diagonal cross terms: (∂²f/∂xᵢ∂xⱼ) * ∂cᵢ/∂x ⊗ ∂cⱼ/∂x  for i < j
-    for i in 1:N, j in (i+1):N
+    for i in 1:N, j in (i + 1):N
         idx_ij = (i - 1) * N - (i - 1) * (i - 2) ÷ 2 + (j - i) + 1  # (i,j) entry
-        push!(stmts, :(cnt = ExaModels.hdrpass(
-            t.args[$i], t.args[$j], comp, y1, y2, o2, cnt,
-            adj2 * t.g[$i] * t.g[$j] + adj * t.h[$idx_ij],
-        )))
+        push!(
+            stmts, :(
+                cnt = ExaModels.hdrpass(
+                    t.args[$i], t.args[$j], comp, y1, y2, o2, cnt,
+                    adj2 * t.g[$i] * t.g[$j] + adj * t.h[$idx_ij],
+                )
+            )
+        )
     end
     return quote
         $(stmts...)
@@ -456,30 +464,38 @@ end
 end
 
 @generated function hrpass0(
-    t::ExaModels.SecondAdjointNodeN{F,N},
-    comp,
-    y1,
-    y2,
-    o2,
-    cnt,
-    adj,
-    adj2,
-) where {F,N}
+        t::ExaModels.SecondAdjointNodeN{F, N},
+        comp,
+        y1,
+        y2,
+        o2,
+        cnt,
+        adj,
+        adj2,
+    ) where {F, N}
     stmts = Expr[]
     for k in 1:N
         idx = (k - 1) * N - (k - 1) * (k - 2) ÷ 2 + 1
-        push!(stmts, :(cnt = ExaModels.hrpass(
-            t.args[$k], comp, y1, y2, o2, cnt,
-            adj * t.g[$k],
-            adj2 * t.g[$k]^2 + adj * t.h[$idx],
-        )))
+        push!(
+            stmts, :(
+                cnt = ExaModels.hrpass(
+                    t.args[$k], comp, y1, y2, o2, cnt,
+                    adj * t.g[$k],
+                    adj2 * t.g[$k]^2 + adj * t.h[$idx],
+                )
+            )
+        )
     end
-    for i in 1:N, j in (i+1):N
+    for i in 1:N, j in (i + 1):N
         idx_ij = (i - 1) * N - (i - 1) * (i - 2) ÷ 2 + (j - i) + 1
-        push!(stmts, :(cnt = ExaModels.hdrpass(
-            t.args[$i], t.args[$j], comp, y1, y2, o2, cnt,
-            adj2 * t.g[$i] * t.g[$j] + adj * t.h[$idx_ij],
-        )))
+        push!(
+            stmts, :(
+                cnt = ExaModels.hdrpass(
+                    t.args[$i], t.args[$j], comp, y1, y2, o2, cnt,
+                    adj2 * t.g[$i] * t.g[$j] + adj * t.h[$idx_ij],
+                )
+            )
+        )
     end
     return quote
         $(stmts...)
diff --git a/src/jacobian.jl b/src/jacobian.jl
index 581dd38..4d2fae8 100644
--- a/src/jacobian.jl
+++ b/src/jacobian.jl
@@ -39,15 +39,15 @@ end
     return cnt
 end
 @generated function jrpass(
-    d::ExaModels.AdjointNodeN{F,N},
-    comp,
-    i,
-    y1,
-    y2,
-    o1,
-    cnt,
-    adj,
-) where {F,N}
+        d::ExaModels.AdjointNodeN{F, N},
+        comp,
+        i,
+        y1,
+        y2,
+        o1,
+        cnt,
+        adj,
+    ) where {F, N}
     stmts = [:(cnt = ExaModels.jrpass(d.args[$k], comp, i, y1, y2, o1, cnt, adj * d.g[$k])) for k in 1:N]
     return quote
         $(stmts...)
diff --git a/src/register.jl b/src/register.jl
index ac347eb..dfafa8f 100644
--- a/src/register.jl
+++ b/src/register.jl
@@ -34,12 +34,12 @@ macro register_multivariate(f, N_expr, grad, hess)
     N = N_expr isa Integer ? N_expr : N_expr  # handled at parse time
     # Build (d1, d2, ..., dN) argument symbols
     arg_syms = [Symbol("_d", k) for k in 1:N]
-    arg_sym_types_node     = [:($(Symbol("D",k))<:ExaModels.AbstractNode) for k in 1:N]
-    arg_sym_types_adjoint  = [:($(Symbol("D",k))<:ExaModels.AbstractAdjointNode) for k in 1:N]
-    arg_sym_types_sadjoint = [:($(Symbol("D",k))<:ExaModels.AbstractSecondAdjointNode) for k in 1:N]
-    arg_decls_node     = [:($(arg_syms[k])::$(Symbol("D",k))) for k in 1:N]
-    arg_decls_adjoint  = [:($(arg_syms[k])::$(Symbol("D",k))) for k in 1:N]
-    arg_decls_sadjoint = [:($(arg_syms[k])::$(Symbol("D",k))) for k in 1:N]
+    arg_sym_types_node = [:($(Symbol("D", k)) <: ExaModels.AbstractNode) for k in 1:N]
+    arg_sym_types_adjoint = [:($(Symbol("D", k)) <: ExaModels.AbstractAdjointNode) for k in 1:N]
+    arg_sym_types_sadjoint = [:($(Symbol("D", k)) <: ExaModels.AbstractSecondAdjointNode) for k in 1:N]
+    arg_decls_node = [:($(arg_syms[k])::$(Symbol("D", k))) for k in 1:N]
+    arg_decls_adjoint = [:($(arg_syms[k])::$(Symbol("D", k))) for k in 1:N]
+    arg_decls_sadjoint = [:($(arg_syms[k])::$(Symbol("D", k))) for k in 1:N]
 
     # Primal values of each child (for calling f, grad, hess at the evaluation point)
     xvals = [:($(arg_syms[k]).x) for k in 1:N]
@@ -49,14 +49,14 @@ macro register_multivariate(f, N_expr, grad, hess)
             # 1. AbstractNode overload — builds symbolic graph node
             if !hasmethod($f, Tuple{$(fill(:(ExaModels.AbstractNode), N)...)})
                 @inline function $f($(arg_decls_node...)) where {$(arg_sym_types_node...)}
-                    ExaModels.NodeN($f, ($(arg_syms...),))
+                    return ExaModels.NodeN($f, ($(arg_syms...),))
                 end
             end
 
             # 2. AbstractAdjointNode overload — first-order AD.
             #    grad(x1,...,xN) returns NTuple{N} — no heap allocation, GPU-safe.
             @inline function $f($(arg_decls_adjoint...)) where {$(arg_sym_types_adjoint...)}
-                ExaModels.AdjointNodeN(
+                return ExaModels.AdjointNodeN(
                     $f,
                     $f($(xvals...)),
                     $grad($(xvals...)),
@@ -68,7 +68,7 @@ macro register_multivariate(f, N_expr, grad, hess)
             #    grad(x1,...,xN) and hess(x1,...,xN) return NTuples —
             #    no heap allocation, GPU-safe.
             @inline function $f($(arg_decls_sadjoint...)) where {$(arg_sym_types_sadjoint...)}
-                ExaModels.SecondAdjointNodeN(
+                return ExaModels.SecondAdjointNodeN(
                     $f,
                     $f($(xvals...)),
                     $grad($(xvals...)),
@@ -78,7 +78,7 @@ macro register_multivariate(f, N_expr, grad, hess)
             end
 
             # 4. Evaluation overload for NodeN
-            @inline (n::ExaModels.NodeN{typeof($f),$N})(i, x, θ) =
+            @inline (n::ExaModels.NodeN{typeof($f), $N})(i, x, θ) =
                 $f($([:(n.args[$k](i, x, θ)) for k in 1:N]...))
         end,
     )
diff --git a/test/NLPTest/multivariate_test.jl b/test/NLPTest/multivariate_test.jl
index 6fd05b8..2173741 100644
--- a/test/NLPTest/multivariate_test.jl
+++ b/test/NLPTest/multivariate_test.jl
@@ -27,7 +27,7 @@ _hess_g2(x, y) = (-sin(x) * cos(y), -cos(x) * sin(y), -sin(x) * cos(y))
 """
 Finite-difference gradient of ExaModel objective or single constraint.
 """
-function fd_gradient(m, x0; h = 1e-5)
+function fd_gradient(m, x0; h = 1.0e-5)
     n = length(x0)
     g = zeros(n)
     f0 = NLPModels.obj(m, x0)
@@ -38,7 +38,7 @@ function fd_gradient(m, x0; h = 1e-5)
     return g
 end
 
-function fd_jacobian(m, x0; h = 1e-5)
+function fd_jacobian(m, x0; h = 1.0e-5)
     n = length(x0)
     ncon = m.meta.ncon
     J = zeros(ncon, n)
@@ -50,7 +50,7 @@ function fd_jacobian(m, x0; h = 1e-5)
     return J
 end
 
-function fd_hessian_lag(m, x0, y0; h = 1e-5)
+function fd_hessian_lag(m, x0, y0; h = 1.0e-5)
     n = length(x0)
     H = zeros(n, n)
     g0 = NLPModels.grad(m, x0) .+ NLPModels.jtprod(m, x0, y0)
@@ -71,21 +71,21 @@ Test @register_multivariate with a 3-argument quadratic function as objective.
 Compares ExaModels gradient/Hessian to finite differences.
 """
 function test_multivariate_objective(backend)
-    @testset "register_multivariate: 3-arg objective" begin
+    return @testset "register_multivariate: 3-arg objective" begin
         N = 6
         c = ExaCore(; backend = backend)
         x = variable(c, N; start = [Float64(i) for i in 1:N])
 
         # objective: sum_i _f3(x[i], x[i+1], x[i+2])
-        objective(c, _f3(x[i], x[i + 1], x[i + 2]) for i in 1:N-2)
+        objective(c, _f3(x[i], x[i + 1], x[i + 2]) for i in 1:(N - 2))
 
         m = ExaModel(c)
         w = WrapperNLPModel(m)
         x0 = copy(w.meta.x0)
 
         g_exa = NLPModels.grad(w, x0)
-        g_fd  = fd_gradient(w, x0)
-        @test g_exa ≈ g_fd atol = 1e-4
+        g_fd = fd_gradient(w, x0)
+        @test g_exa ≈ g_fd atol = 1.0e-4
 
         # Hessian (no constraints, so y = [])
         hI = zeros(Int, w.meta.nnzh)
@@ -103,7 +103,7 @@ function test_multivariate_objective(backend)
         end
 
         H_fd = fd_hessian_lag(w, x0, Float64[])
-        @test H_exa ≈ H_fd atol = 1e-3
+        @test H_exa ≈ H_fd atol = 1.0e-3
     end
 end
 
@@ -112,7 +112,7 @@ Test @register_multivariate with a 2-argument function as constraint.
 Compares ExaModels Jacobian to finite differences.
 """
 function test_multivariate_constraint(backend)
-    @testset "register_multivariate: 2-arg constraint" begin
+    return @testset "register_multivariate: 2-arg constraint" begin
         N = 4
         c = ExaCore(; backend = backend)
         x = variable(c, N; start = [0.5 + 0.1 * Float64(i) for i in 1:N])
@@ -120,7 +120,7 @@ function test_multivariate_constraint(backend)
         # constraint: _g2(x[i], x[i+1]) ∈ [-1, 1]  for i = 1:N-1
         constraint(
             c,
-            _g2(x[i], x[i + 1]) for i in 1:N-1;
+            _g2(x[i], x[i + 1]) for i in 1:(N - 1);
             lcon = -ones(N - 1),
             ucon = ones(N - 1),
         )
@@ -145,7 +145,7 @@ function test_multivariate_constraint(backend)
         end
 
         J_fd = fd_jacobian(w, x0)
-        @test J_exa ≈ J_fd atol = 1e-4
+        @test J_exa ≈ J_fd atol = 1.0e-4
 
         # Hessian of Lagrangian
         y0 = randn(N - 1)
@@ -164,7 +164,7 @@ function test_multivariate_constraint(backend)
         end
 
         H_fd = fd_hessian_lag(w, x0, y0)
-        @test H_exa ≈ H_fd atol = 1e-3
+        @test H_exa ≈ H_fd atol = 1.0e-3
     end
 end
 
@@ -173,7 +173,7 @@ Test that @register_multivariate interoperates correctly with ExaModels'
 native symbolic operations (Node1, Node2) in the same expression.
 """
 function test_multivariate_mixed_expression(backend)
-    @testset "register_multivariate: mixed with native ops" begin
+    return @testset "register_multivariate: mixed with native ops" begin
         N = 6
         c = ExaCore(; backend = backend)
         x = variable(c, N; start = [0.3 * Float64(i) for i in 1:N])
@@ -182,7 +182,7 @@ function test_multivariate_mixed_expression(backend)
         # This tests that NodeN can be a child of a Node1/Node2 and vice versa.
         objective(
             c,
-            _f3(x[i], sin(x[i + 1]), x[i + 2]^2) for i in 1:N-2
+            _f3(x[i], sin(x[i + 1]), x[i + 2]^2) for i in 1:(N - 2)
         )
 
         m = ExaModel(c)
@@ -190,8 +190,8 @@ function test_multivariate_mixed_expression(backend)
         x0 = copy(w.meta.x0)
 
         g_exa = NLPModels.grad(w, x0)
-        g_fd  = fd_gradient(w, x0)
-        @test g_exa ≈ g_fd atol = 1e-4
+        g_fd = fd_gradient(w, x0)
+        @test g_exa ≈ g_fd atol = 1.0e-4
     end
 end
 
@@ -201,5 +201,5 @@ Run all @register_multivariate tests.
 function test_multivariate(backend)
     test_multivariate_objective(backend)
     test_multivariate_constraint(backend)
-    test_multivariate_mixed_expression(backend)
+    return test_multivariate_mixed_expression(backend)
 end

@andrewrosemberg
Copy link
Author

closing in favor of #243

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add @register_multivariate for n-argument custom operators

1 participant