Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/recursive_array_functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ and do not require that the RecursiveArrayTools types are used.
```@docs
recursivecopy
recursivecopy!
recursivecopyto!
vecvecapply
copyat_or_push!
```
3 changes: 2 additions & 1 deletion src/RecursiveArrayTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ module RecursiveArrayTools
export DEFAULT_PLOT_FUNC, plottable_indices, plot_indices, getindepsym_defaultt,
interpret_vars, add_labels!, diffeq_to_arrays, solplot_vecs_and_labels

export recursivecopy, recursivecopy!, recursivefill!, vecvecapply, copyat_or_push!,
export recursivecopy, recursivecopy!, recursivecopyto!, recursivefill!, vecvecapply,
copyat_or_push!,
vecvec_to_mat, recursive_one, recursive_mean, recursive_bottom_eltype,
recursive_unitless_bottom_eltype, recursive_unitless_eltype

Expand Down
17 changes: 17 additions & 0 deletions src/array_partition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,23 @@ function recursivecopy!(
return A
end

function recursivecopyto!(A::ArrayPartition, B::ArrayPartition)
for (a, b) in zip(A.x, B.x)
recursivecopyto!(a, b)
end
return A
end

function recursivecopyto!(
A::ArrayPartition{T, S},
B::ArrayPartition{T, S}
) where {T, S <: Tuple{Vararg{AbstractVectorOfArray}}}
for i in eachindex(A.x, B.x)
recursivecopyto!(A.x[i], B.x[i])
end
return A
end

function recursive_mean(A::ArrayPartition)
n = npartitions(A)
if n == 0
Expand Down
66 changes: 65 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ recursivecopy!(b::AbstractArray{T, N}, a::AbstractArray{T, N})
```

A recursive `copy!` function. Acts like a `deepcopy!` on arrays of arrays, but
like `copy!` on arrays of scalars.
like `copy!` on arrays of scalars. Requires `b` and `a` to have matching `ndims`;
use [`recursivecopyto!`](@ref) for the `copyto!`-style linear-index variant that
allows mismatched shapes.
"""
function recursivecopy! end

Expand Down Expand Up @@ -105,6 +107,68 @@ function recursivecopy!(b::AbstractVectorOfArray, a::AbstractVectorOfArray)
return b
end

"""
```julia
recursivecopyto!(b::AbstractArray, a::AbstractArray)
```

A recursive `copyto!` function. Acts like a `deepcopy!` on arrays of arrays, but
like `copyto!` on arrays of scalars.

Unlike [`recursivecopy!`](@ref), this does not require `b` and `a` to have matching
`ndims` or axes; only that `length(b) >= length(a)`. Elements are copied in linear
(column-major) order, matching the semantics of `Base.copyto!`. Use this when
flattening/reshaping between destination and source is intended, e.g. copying a
`Vector` into a `Matrix` of the same total length.
"""
function recursivecopyto! end

function recursivecopyto!(b::AbstractArray, a::AbstractArray)
return copyto!(b, a)
end

function recursivecopyto!(
b::AbstractArray{T},
a::AbstractArray{T2}
) where {
T <: StaticArraysCore.StaticArray,
T2 <: StaticArraysCore.StaticArray,
}
@inbounds for (ib, ia) in zip(eachindex(b), eachindex(a))
# TODO: Check for `setindex!`` and use `copy!(b[i],a[i])` or `b[i] = a[i]`, see #19
b[ib] = copy(a[ia])
end
return b
end

function recursivecopyto!(
b::AbstractArray{T},
a::AbstractArray{T2}
) where {
T <: Union{AbstractArray, AbstractVectorOfArray},
T2 <: Union{AbstractArray, AbstractVectorOfArray},
}
if ArrayInterface.ismutable(T)
@inbounds for (ib, ia) in zip(eachindex(b), eachindex(a))
recursivecopyto!(b[ib], a[ia])
end
else
copyto!(b, a)
end
return b
end

function recursivecopyto!(b::AbstractVectorOfArray, a::AbstractVectorOfArray)
@inbounds for i in eachindex(b.u, a.u)
if ArrayInterface.ismutable(b.u[i]) || b.u[i] isa AbstractVectorOfArray
recursivecopyto!(b.u[i], a.u[i])
else
b.u[i] = recursivecopy(a.u[i])
end
end
return b
end

"""
```julia
recursivefill!(b::AbstractArray{T, N}, a)
Expand Down
81 changes: 81 additions & 0 deletions test/utils_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,87 @@ end
@test a.u[1][1] == 1.0
end

@testset "recursivecopyto!" begin
# Same-shape scalar arrays — should match copyto!
b = zeros(3)
a = [1.0, 2.0, 3.0]
recursivecopyto!(b, a)
@test b == a

b = zeros(2, 2)
a = [1.0 2.0; 3.0 4.0]
recursivecopyto!(b, a)
@test b == a

# Issue #589: Matrix ← Vector of matching length (rejected by recursivecopy!,
# allowed by recursivecopyto!).
b = zeros(2, 3)
a = collect(1.0:6.0)
recursivecopyto!(b, a)
@test b == reshape(a, 2, 3)
@test_throws MethodError recursivecopy!(b, a)

# Vector ← Matrix
b = zeros(6)
a = reshape(collect(1.0:6.0), 2, 3)
recursivecopyto!(b, a)
@test b == collect(1.0:6.0)

# Different-shape matrices, same total length
b = zeros(2, 3)
a = reshape(collect(1.0:6.0), 3, 2)
recursivecopyto!(b, a)
@test vec(b) == 1.0:6.0

# dst longer than src — tail untouched, matches Base.copyto!
b = ones(5)
a = [10.0, 20.0, 30.0]
recursivecopyto!(b, a)
@test b == [10.0, 20.0, 30.0, 1.0, 1.0]

# dst shorter than src — BoundsError, matches Base.copyto!
b = zeros(2)
a = [1.0, 2.0, 3.0]
@test_throws BoundsError recursivecopyto!(b, a)

# Nested: Vector of Vectors, matching shapes
a = [ones(3), 2 * ones(3)]
b = [zeros(3), zeros(3)]
recursivecopyto!(b, a)
@test b[1] == ones(3) && b[2] == 2 * ones(3)
# Verify deep copy semantics — mutating dst leaves src untouched
b[1][1] = 99.0
@test a[1][1] == 1.0

# Nested with shape mismatch at the leaves — inner copyto! handles it
a = [collect(1.0:6.0), collect(7.0:12.0)]
b = [zeros(2, 3), zeros(2, 3)]
recursivecopyto!(b, a)
@test b[1] == reshape(1.0:6.0, 2, 3)
@test b[2] == reshape(7.0:12.0, 2, 3)

# Static array element
a = [@SVector([1.0, 2.0]), @SVector([3.0, 4.0])]
b = [@SVector(zeros(2)), @SVector(zeros(2))]
recursivecopyto!(b, a)
@test b == a

# ArrayPartition with matching shapes (sanity — parity with recursivecopy!)
A = ArrayPartition(zeros(2), zeros(3))
B = ArrayPartition([1.0, 2.0], [3.0, 4.0, 5.0])
recursivecopyto!(A, B)
@test A.x[1] == [1.0, 2.0]
@test A.x[2] == [3.0, 4.0, 5.0]

# VectorOfArray
u1 = VA[zeros(MVector{2, Float64}), zeros(MVector{2, Float64})]
u2 = VA[fill(4, MVector{2, Float64}), 2 .* ones(MVector{2, Float64})]
recursivecopyto!(u1, u2)
@test u1.u[1] == [4.0, 4.0]
@test u1.u[2] == [2.0, 2.0]
@test u1.u[1] isa MVector
end

@testset "VectorOfArray similar with nested scalar leaves" begin
a = VA[ones(2), VA[1.0, 1.0]]
b = similar(a, Float64)
Expand Down
Loading