diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index 5a90faa..45e5520 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -171,19 +171,32 @@ function _forward_eval( end elseif node.index == 3 # :* # Node `k` is not scalar, so we do matrix multiplication + # (or scalar `*` matrix scaling when one operand is scalar). if f.sizes.ndims[k] != 0 @assert N == 2 idx1 = first(children_indices) idx2 = last(children_indices) @inbounds ix1 = children_arr[idx1] @inbounds ix2 = children_arr[idx2] - v1 = _view_matrix(f.forward_storage, f.sizes, ix1) - v2 = _view_matrix(f.forward_storage, f.sizes, ix2) - out = _view_matrix(f.forward_storage, f.sizes, k) - LinearAlgebra.mul!(out, v1, v2) + out = _view_linear(f.forward_storage, f.sizes, k) + if f.sizes.ndims[ix1] == 0 + s = _getscalar(f.forward_storage, f.sizes, ix1) + v = _view_linear(f.forward_storage, f.sizes, ix2) + out .= s .* v + elseif f.sizes.ndims[ix2] == 0 + v = _view_linear(f.forward_storage, f.sizes, ix1) + s = _getscalar(f.forward_storage, f.sizes, ix2) + out .= v .* s + else + v1 = _view_matrix(f.forward_storage, f.sizes, ix1) + v2 = _view_matrix(f.forward_storage, f.sizes, ix2) + out_m = _view_matrix(f.forward_storage, f.sizes, k) + LinearAlgebra.mul!(out_m, v1, v2) + end # We deliberately don't write v1/v2 into partials_storage - # here: the matmul reverse branch reads forward_storage - # directly, so those writes were dead. + # here: the matmul (or scalar-scaling) reverse branch + # reads forward_storage directly, so those writes were + # dead. # Node `k` is scalar else tmp_prod = one(T) @@ -620,23 +633,55 @@ function _reverse_eval( op = DEFAULT_MULTIVARIATE_OPERATORS[node.index] if op == :* if f.sizes.ndims[k] != 0 - # Matrix multiplication: rev_v1 = rev_parent * v2', - # rev_v2 = v1' * rev_parent. Both v1 and v2 are read - # straight from forward_storage (the matmul forward - # branch deliberately doesn't snapshot them into - # partials_storage), and the reverse views are written - # in place. 
+ # Matmul (or `scalar * matrix` scaling): rev_v1 = + # rev_parent * v2', rev_v2 = v1' * rev_parent. With + # a scalar operand, the result is `s .* M`, so + # rev[s] = sum(rev_parent .* M) and rev[M] = + # rev_parent .* s. Both v1 and v2 are read straight + # from forward_storage. idx1 = first(children_indices) idx2 = last(children_indices) ix1 = children_arr[idx1] ix2 = children_arr[idx2] - v1 = _view_matrix(f.forward_storage, f.sizes, ix1) - v2 = _view_matrix(f.forward_storage, f.sizes, ix2) - rev_parent = _view_matrix(f.reverse_storage, f.sizes, k) - rev_v1 = _view_matrix(f.reverse_storage, f.sizes, ix1) - rev_v2 = _view_matrix(f.reverse_storage, f.sizes, ix2) - LinearAlgebra.mul!(rev_v1, rev_parent, v2') - LinearAlgebra.mul!(rev_v2, v1', rev_parent) + rev_parent = _view_linear(f.reverse_storage, f.sizes, k) + ndims1 = f.sizes.ndims[ix1] + ndims2 = f.sizes.ndims[ix2] + if ndims1 == 0 && ndims2 != 0 + v2 = _view_linear(f.forward_storage, f.sizes, ix2) + s1 = _getscalar(f.forward_storage, f.sizes, ix1) + rev_v2 = + _view_linear(f.reverse_storage, f.sizes, ix2) + rev_v2 .= rev_parent .* s1 + _setscalar!( + f.reverse_storage, + LinearAlgebra.dot(rev_parent, v2), + f.sizes, + ix1, + ) + elseif ndims1 != 0 && ndims2 == 0 + v1 = _view_linear(f.forward_storage, f.sizes, ix1) + s2 = _getscalar(f.forward_storage, f.sizes, ix2) + rev_v1 = + _view_linear(f.reverse_storage, f.sizes, ix1) + rev_v1 .= rev_parent .* s2 + _setscalar!( + f.reverse_storage, + LinearAlgebra.dot(rev_parent, v1), + f.sizes, + ix2, + ) + else + v1 = _view_matrix(f.forward_storage, f.sizes, ix1) + v2 = _view_matrix(f.forward_storage, f.sizes, ix2) + rev_parent_m = + _view_matrix(f.reverse_storage, f.sizes, k) + rev_v1 = + _view_matrix(f.reverse_storage, f.sizes, ix1) + rev_v2 = + _view_matrix(f.reverse_storage, f.sizes, ix2) + LinearAlgebra.mul!(rev_v1, rev_parent_m, v2') + LinearAlgebra.mul!(rev_v2, v1', rev_parent_m) + end continue end elseif op == :vect diff --git a/test/JuMP.jl b/test/JuMP.jl index 
533e631..75fa287 100644
--- a/test/JuMP.jl
+++ b/test/JuMP.jl
@@ -490,6 +490,53 @@ function test_broadcast_scalar_matrix_size_inference()
     return
 end
 
+# Exercise the non-broadcasted `:*` scalar-times-matrix branches.
+#
+# JuMP's `Number * MatrixOfVariables` falls through to `Base.broadcasted`
+# (no specialized method in `src/JuMP/operators.jl`), so the standard
+# `c * W` syntax actually parses as broadcasted `:*`. To hit the
+# non-broadcasted branch we build a `MatrixExpr` with `broadcasted = false`
+# directly — the same shape `_matmul` produces, but with one scalar child
+# instead of two matrix children. This is the path a hand-built
+# `MOI.ScalarNonlinearFunction(:*, Any[c, W])` would land on.
+#
+# Both operand orders are checked: `_forward_eval` and `_reverse_eval`
+# handle a scalar first child and a scalar second child in separate
+# branches, so `Any[c, W]` alone would leave the mirrored branch untested.
+#
+# NOTE(review): `c` is a constant, so the gradient assertion only observes
+# the reverse value propagated to `W`; the `dot` written for the scalar
+# child is not checked here.
+function test_scalar_matrix_product_nonbroadcasted()
+    n, c = 2, 0.5
+    model = Model()
+    @variable(model, W[1:n, 1:n], container = ArrayDiff.ArrayOfVariables)
+    x = Float64[1, 2, 3, 4]
+    for args in (Any[c, W], Any[W, c])
+        mat_expr = ArrayDiff.MatrixExpr(:*, args, (n, n), false)
+        @test mat_expr.head == :*
+        @test !mat_expr.broadcasted
+        @test size(mat_expr) == (n, n)
+        loss = sum(mat_expr)
+        mode = ArrayDiff.Mode()
+        ad = ArrayDiff.model(mode)
+        MOI.Nonlinear.set_objective(ad, JuMP.moi_function(loss))
+        evaluator = MOI.Nonlinear.Evaluator(
+            ad,
+            mode,
+            JuMP.index.(JuMP.all_variables(model)),
+        )
+        MOI.initialize(evaluator, [:Grad])
+        # `sum(c * W) = sum(W * c) = c * sum(W)`
+        @test MOI.eval_objective(evaluator, x) ≈ c * sum(x)
+        g = zeros(n * n)
+        MOI.eval_objective_gradient(evaluator, g, x)
+        # `∂/∂W_ij sum(c * W) = c`, constant, for either operand order.
+        @test g ≈ fill(c, n * n)
+    end
+    return
+end
+
 end # module
 
 TestJuMP.runtests()