diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index faa14a5..7b3f051 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -136,11 +136,10 @@ function _forward_eval( tmp_sum = zero(T) for c_idx in children_indices ix = children_arr[c_idx] - _setindex!(f.partials_storage, one(T), f.sizes, ix, j) - f.partials_storage[ix] = one(T) - tmp_sum += _getindex(f.forward_storage, f.sizes, ix, j) + @j f.partials_storage[ix] = one(T) + tmp_sum += @j f.forward_storage[ix] end - _setindex!(f.forward_storage, tmp_sum, f.sizes, k, j) + @j f.forward_storage[k] = tmp_sum end elseif node.index == 2 # :- @assert N == 2 diff --git a/src/sizes.jl b/src/sizes.jl index b745156..cecadbd 100644 --- a/src/sizes.jl +++ b/src/sizes.jl @@ -49,6 +49,27 @@ function _setindex!(x, value, sizes::Sizes, k::Int, j) return x[sizes.storage_offset[k]+j] = value end +""" + @j(storage[node]) -> _getindex(storage, f.sizes, node, j) + @j(storage[node] = value) -> _setindex!(storage, value, f.sizes, node, j) + +This "at `j`" converts `getindex` and `setindex!` calls to access +the sub-array in a vector corresponding to a node at its `j`th index. +""" +macro j(expr) + if Meta.isexpr(expr, :(=)) && length(expr.args) == 2 + lhs, rhs = expr.args + @assert Meta.isexpr(lhs, :ref) + @assert length(expr.args) == 2 + return Expr(:call, :_setindex!, esc(lhs.args[1]), esc(rhs), esc(:(f.sizes)), esc(lhs.args[2]), esc(:j)) + elseif Meta.isexpr(expr, :ref) && length(expr.args) == 2 + arr, idx = expr.args + return Expr(:call, :_getindex, esc(arr), esc(:(f.sizes)), esc(idx), esc(:j)) + else + error("Unsupported expression `$expr`") + end +end + # /!\ Can only be called in decreasing `k` order function _add_size!(sizes::Sizes, k::Int, size::Tuple) sizes.ndims[k] = length(size)