Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 64 additions & 19 deletions src/reverse_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,19 +171,32 @@ function _forward_eval(
end
elseif node.index == 3 # :*
# Node `k` is not scalar, so we do matrix multiplication
# (or scalar `*` matrix scaling when one operand is scalar).
if f.sizes.ndims[k] != 0
@assert N == 2
idx1 = first(children_indices)
idx2 = last(children_indices)
@inbounds ix1 = children_arr[idx1]
@inbounds ix2 = children_arr[idx2]
v1 = _view_matrix(f.forward_storage, f.sizes, ix1)
v2 = _view_matrix(f.forward_storage, f.sizes, ix2)
out = _view_matrix(f.forward_storage, f.sizes, k)
LinearAlgebra.mul!(out, v1, v2)
out = _view_linear(f.forward_storage, f.sizes, k)
if f.sizes.ndims[ix1] == 0
s = _getscalar(f.forward_storage, f.sizes, ix1)
v = _view_linear(f.forward_storage, f.sizes, ix2)
out .= s .* v
elseif f.sizes.ndims[ix2] == 0
v = _view_linear(f.forward_storage, f.sizes, ix1)
s = _getscalar(f.forward_storage, f.sizes, ix2)
out .= v .* s
else
v1 = _view_matrix(f.forward_storage, f.sizes, ix1)
v2 = _view_matrix(f.forward_storage, f.sizes, ix2)
out_m = _view_matrix(f.forward_storage, f.sizes, k)
LinearAlgebra.mul!(out_m, v1, v2)
end
# We deliberately don't write v1/v2 into partials_storage
# here: the matmul reverse branch reads forward_storage
# directly, so those writes were dead.
# here: the matmul (or scalar-scaling) reverse branch
# reads forward_storage directly, so those writes were
# dead.
# Node `k` is scalar
else
tmp_prod = one(T)
Expand Down Expand Up @@ -620,23 +633,55 @@ function _reverse_eval(
op = DEFAULT_MULTIVARIATE_OPERATORS[node.index]
if op == :*
if f.sizes.ndims[k] != 0
# Matrix multiplication: rev_v1 = rev_parent * v2',
# rev_v2 = v1' * rev_parent. Both v1 and v2 are read
# straight from forward_storage (the matmul forward
# branch deliberately doesn't snapshot them into
# partials_storage), and the reverse views are written
# in place.
# Matmul (or `scalar * matrix` scaling): rev_v1 =
# rev_parent * v2', rev_v2 = v1' * rev_parent. With
# a scalar operand, the result is `s .* M`, so
# rev[s] = sum(rev_parent .* M) and rev[M] =
# rev_parent .* s. Both v1 and v2 are read straight
# from forward_storage.
idx1 = first(children_indices)
idx2 = last(children_indices)
ix1 = children_arr[idx1]
ix2 = children_arr[idx2]
v1 = _view_matrix(f.forward_storage, f.sizes, ix1)
v2 = _view_matrix(f.forward_storage, f.sizes, ix2)
rev_parent = _view_matrix(f.reverse_storage, f.sizes, k)
rev_v1 = _view_matrix(f.reverse_storage, f.sizes, ix1)
rev_v2 = _view_matrix(f.reverse_storage, f.sizes, ix2)
LinearAlgebra.mul!(rev_v1, rev_parent, v2')
LinearAlgebra.mul!(rev_v2, v1', rev_parent)
rev_parent = _view_linear(f.reverse_storage, f.sizes, k)
ndims1 = f.sizes.ndims[ix1]
ndims2 = f.sizes.ndims[ix2]
if ndims1 == 0 && ndims2 != 0
v2 = _view_linear(f.forward_storage, f.sizes, ix2)
s1 = _getscalar(f.forward_storage, f.sizes, ix1)
rev_v2 =
_view_linear(f.reverse_storage, f.sizes, ix2)
rev_v2 .= rev_parent .* s1
_setscalar!(
f.reverse_storage,
LinearAlgebra.dot(rev_parent, v2),
f.sizes,
ix1,
)
elseif ndims1 != 0 && ndims2 == 0
v1 = _view_linear(f.forward_storage, f.sizes, ix1)
s2 = _getscalar(f.forward_storage, f.sizes, ix2)
rev_v1 =
_view_linear(f.reverse_storage, f.sizes, ix1)
rev_v1 .= rev_parent .* s2
_setscalar!(
f.reverse_storage,
LinearAlgebra.dot(rev_parent, v1),
f.sizes,
ix2,
)
else
v1 = _view_matrix(f.forward_storage, f.sizes, ix1)
v2 = _view_matrix(f.forward_storage, f.sizes, ix2)
rev_parent_m =
_view_matrix(f.reverse_storage, f.sizes, k)
rev_v1 =
_view_matrix(f.reverse_storage, f.sizes, ix1)
rev_v2 =
_view_matrix(f.reverse_storage, f.sizes, ix2)
LinearAlgebra.mul!(rev_v1, rev_parent_m, v2')
LinearAlgebra.mul!(rev_v2, v1', rev_parent_m)
end
continue
end
elseif op == :vect
Expand Down
37 changes: 37 additions & 0 deletions test/JuMP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,43 @@ function test_broadcast_scalar_matrix_size_inference()
return
end

# Exercise the non-broadcasted `:*` scalar-times-matrix branch.
#
# JuMP's `Number * MatrixOfVariables` falls through to `Base.broadcasted`
# (no specialized method in `src/JuMP/operators.jl`), so the standard
# `c * W` syntax actually parses as broadcasted `:*`. To hit the
# non-broadcasted branch we build a `MatrixExpr` with `broadcasted = false`
# directly — the same shape `_matmul` produces, but with one scalar child
# instead of two matrix children. This is the path a hand-built
# `MOI.ScalarNonlinearFunction(:*, Any[c, W])` would land on.
function test_scalar_matrix_product_nonbroadcasted()
    # Problem data: a constant `scale` multiplying a `dim × dim` matrix of
    # decision variables.
    dim = 2
    scale = 0.5
    shape = (dim, dim)
    model = Model()
    @variable(
        model,
        W[1:dim, 1:dim],
        container = ArrayDiff.ArrayOfVariables,
    )
    # Hand-build the `:*` node with its last argument `false`, so the
    # resulting expression reports `broadcasted == false`.
    product = ArrayDiff.MatrixExpr(:*, Any[scale, W], shape, false)
    @test product.head == :*
    @test !product.broadcasted
    @test size(product) == shape
    # Use `sum(scale * W)` as the objective of an ArrayDiff-backed model.
    loss = sum(product)
    mode = ArrayDiff.Mode()
    nlp = ArrayDiff.model(mode)
    objective = JuMP.moi_function(loss)
    MOI.Nonlinear.set_objective(nlp, objective)
    evaluator = MOI.Nonlinear.Evaluator(
        nlp,
        mode,
        JuMP.index.(JuMP.all_variables(model)),
    )
    MOI.initialize(evaluator, [:Grad])
    point = Float64[1, 2, 3, 4]
    # Objective value: `sum(scale * W) = scale * sum(W)`.
    expected_value = scale * sum(point)
    @test MOI.eval_objective(evaluator, point) ≈ expected_value
    # Gradient: `∂/∂W_ij sum(scale * W) = scale` for every entry.
    grad = zeros(dim * dim)
    MOI.eval_objective_gradient(evaluator, grad, point)
    expected_grad = fill(scale, dim * dim)
    @test grad ≈ expected_grad
    return
end

end # module

TestJuMP.runtests()
Loading