diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0a286d1..f98bf06 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,12 +1,15 @@ name: CI on: - pull_request: - branches: - - master push: branches: - master - tags: '*' + tags: ['*'] + pull_request: +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} jobs: test: name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} @@ -15,33 +18,26 @@ jobs: fail-fast: false matrix: version: - - '1.0' # Replace this with the minimum Julia version that your package supports. E.g. if your package requires Julia 1.5 or higher, change this to '1.5'. - - '1' # Leave this line unchanged. '1' will automatically expand to the latest stable 1.x release of Julia. - os: [ubuntu-latest, windows-latest, macOS-latest] + - '1' + os: + - ubuntu-latest + - windows-latest # Add this line to include the latest Windows system arch: - x64 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - - uses: actions/cache@v1 - env: - cache-name: cache-artifacts - with: - path: ~/.julia/artifacts - key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} - restore-keys: | - ${{ runner.os }}-test-${{ env.cache-name }}- - ${{ runner.os }}-test- - ${{ runner.os }}- + - uses: julia-actions/cache@v1 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v1 - with: - file: lcov.info + - uses: codecov/codecov-action@v5 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + docs: name: Documentation runs-on: ubuntu-latest diff --git a/Project.toml b/Project.toml index 8f5990b..1af2c8d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,14 +1,23 @@ name = "ShiftedArrays" uuid = "1277b4bf-5013-50f5-be3d-901d8477a67a" -repo = "https://github.com/JuliaArrays/ShiftedArrays.jl.git" version = "2.0.0" [compat] julia = "1" +CUDA = "5.2, 5.3, 5.4, 5.5, 5.6, 5.7" +Adapt = "3.7, 4.0, 4.1" + +[extensions] +CUDASupportExt = ["CUDA", "Adapt"] [extras] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[weakdeps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" [targets] -test = ["Test", "AbstractFFTs"] +test = ["Test", "AbstractFFTs", "Random", "CUDA"] diff --git a/ext/CUDASupportExt.jl b/ext/CUDASupportExt.jl new file mode 100644 index 0000000..c60b41f --- /dev/null +++ b/ext/CUDASupportExt.jl @@ -0,0 +1,99 @@ +module CUDASupportExt +using CUDA +using Adapt +using ShiftedArrays +using Base + +get_base_arr(arr::CuArray) = arr +get_base_arr(arr::Array) = arr +function get_base_arr(arr::AbstractArray) + p = parent(arr) + return (p === arr) ? arr : get_base_arr(parent(arr)) +end + +# define a number of Union types to not repeat all definitions for each type +AllShiftedTypeCu{N, CD} = Union{CircShiftedArray{<:Any,<:Any,<:CuArray{<:Any,N,CD}}, + ShiftedArray{<:Any,<:Any,<:Any,<:CuArray{<:Any,N,CD}}} +AllShiftedTypeCuG{N, CD} = Union{AllShiftedTypeCu{N, CD}, CircShiftedArray{<:Any,<:Any,<:AllShiftedTypeCu{N,CD}}, + ShiftedArray{<:Any,<:Any,<:Any,<:AllShiftedTypeCu{N,CD}}} +AllSubArrayTypeCu{N, CD} = Union{SubArray{<:Any, <:Any, <:AllShiftedTypeCuG{N,CD}, <:Any, <:Any}, + Base.ReshapedArray{<:Any, <:Any, <:AllShiftedTypeCuG{N,CD}, <:Any}, + SubArray{<:Any, <:Any, <:Base.ReshapedArray{<:Any, <:Any, <:AllShiftedTypeCuG{N,CD}, <:Any}, <:Any, <:Any}} +AllShiftedAndViewsCu{N, CD} = Union{AllShiftedTypeCuG{N, CD}, AllSubArrayTypeCu{N, CD}} + +Adapt.adapt_structure(to, x::CircShiftedArray{T, N, S}) where {T, N, S} = CircShiftedArray(adapt(to, parent(x)), shifts(x)); +Adapt.adapt_structure(to, x::ShiftedArray{T, V, N, S}) where {T, V, N, S} = ShiftedArray(adapt(to, parent(x)), shifts(x), default=ShiftedArrays.default(x)); + +function Base.Broadcast.BroadcastStyle(::Type{T}) where {N, CD, T<:AllShiftedTypeCu{N, CD}} + CUDA.CuArrayStyle{N,CD}() +end + +# Define the BroadcastStyle for SubArray of MutableShiftedArray with CuArray + +function Base.Broadcast.BroadcastStyle(::Type{T}) where {N, CD, T<:AllSubArrayTypeCu{N, CD}} + CUDA.CuArrayStyle{N,CD}() +end + +function Base.copy(s::AllShiftedAndViewsCu) + res = similar(get_base_arr(s), eltype(s), size(s)); + res .= s + return res +end + +function Base.collect(x::AllShiftedAndViewsCu) + return copy(x) # stay on the GPU +end + +function Base.Array(x::AllShiftedAndViewsCu) + return Array(copy(x)) # remove from GPU +end + +function Base.:(==)(x::AllShiftedAndViewsCu, y::AbstractArray) + return all(x .== y) +end + +function Base.:(==)(y::AbstractArray, x::AllShiftedAndViewsCu) + return all(x .== y) +end + +function Base.:(==)(x::AllShiftedAndViewsCu, y::AllShiftedAndViewsCu) + return all(x .== y) +end + +function Base.isapprox(x::AllShiftedAndViewsCu, y::AbstractArray; atol=0, rtol=atol>0 ? 0 : sqrt(eps(real(eltype(x)))), va...) + atol = (atol != 0) ? atol : rtol * maximum(abs.(x)) + return all(abs.(x .- y) .<= atol) +end + +function Base.isapprox(y::AbstractArray, x::AllShiftedAndViewsCu; atol=0, rtol=atol>0 ? 0 : sqrt(eps(real(eltype(x)))), va...) + atol = (atol != 0) ? atol : rtol * maximum(abs.(x)) + return all(abs.(x .- y) .<= atol) +end + +function Base.isapprox(x::AllShiftedAndViewsCu, y::AllShiftedAndViewsCu; atol=0, rtol=atol>0 ? 0 : sqrt(eps(real(eltype(x)))), va...) + atol = (atol != 0) ? atol : rtol * maximum(abs.(x)) + return all(abs.(x .- y) .<= atol) +end + +function Base.show(io::IO, mm::MIME"text/plain", cs::AllShiftedAndViewsCu) + CUDA.@allowscalar invoke(Base.show, Tuple{IO, typeof(mm), AbstractArray}, io, mm, cs) +end + +# This version is needed to deal with range access of wrapped CuArrays. +# ShiftedVector(cu([1,2,3,4,5]))[2:3] +@inline function Base.getindex(s::AllShiftedTypeCu{N, CD}, x::Vararg{Union{AbstractRange, Int}, N}) where {N, CD} + v = @view s[x...] + res = similar(s.parent, eltype(s), size(v)) + res .= v +end + +# This specializations are to ensure that true single element accesses generate an error, if allowscalar has not be specified. +@inline function Base.getindex(s::ShiftedArray{A,B,C, <:CuArray{<:Any,N,CD}}, x::Vararg{Int, N}) where {A,B,C, N,CD} + invoke(ShiftedArrays.getindex, Tuple{ShiftedArray{A,B,C,<:AbstractArray}, ntuple((_)->Int, N)...}, s, x...) +end + +@inline function Base.getindex(s::CircShiftedArray{A,B,<:CuArray{<:Any,N,CD}}, x::Vararg{Int, N}) where {A,B,N,CD} + invoke(ShiftedArrays.getindex, Tuple{CircShiftedArray{A,B,<:AbstractArray}, ntuple((_)->Int, N)...}, s, x...) +end + +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 4415ed5..0607c3e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,184 +1,238 @@ using ShiftedArrays, Test using AbstractFFTs +using Random +using CUDA -@testset "ShiftedVector" begin - v = [1, 3, 5, 4] - @test all(v .== ShiftedVector(v)) - sv = ShiftedVector(v, -1) - @test isequal(sv, ShiftedVector(v, (-1,))) - @test length(sv) == 4 - @test all(sv[1:3] .== [3, 5, 4]) - @test ismissing(sv[4]) - diff = v .- sv - @test isequal(diff, [-2, -2, 1, missing]) - @test shifts(sv) == (-1,) - svneg = ShiftedVector(v, -1, default = -100) - @test default(svneg) == -100 - @test copy(svneg) == coalesce.(sv, -100) - @test isequal(sv[1:3], Union{Int64, Missing}[3, 5, 4]) - svnest = ShiftedVector(ShiftedVector(v, 1), 2) - sv = ShiftedVector(v, 3) - @test sv === svnest - sv = ShiftedVector(v, 2, default = nothing) - sv1 = ShiftedVector(sv, 1) - sv2 = ShiftedVector(sv, 1, default = 0) - @test isequal(collect(sv1), [nothing, nothing, nothing, 1]) - @test isequal(collect(sv2), [0, nothing, nothing, 1]) -end - -@testset "ShiftedArray" begin - v = reshape(1:16, 4, 4) - @test all(v .== ShiftedArray(v)) - sv = ShiftedArray(v, (-2, 0)) - @test length(sv) == 16 - @test sv[1, 3] == 11 - @test ismissing(sv[3, 3]) - @test shifts(sv) == (-2,0) - @test isequal(sv, ShiftedArray(v, -2)) - @test isequal(@inferred(ShiftedArray(v, (2,))), @inferred(ShiftedArray(v, 2))) - @test isequal(@inferred(ShiftedArray(v)), @inferred(ShiftedArray(v, (0, 0)))) - s = ShiftedArray(v, (0, -2)) - @test isequal(collect(s), [ 9 13 missing missing; - 10 14 missing missing; - 11 15 missing missing; - 12 16 missing missing]) - sneg = ShiftedArray(v, (0, -2), default = -100) - @test all(sneg .== coalesce.(s, default(sneg))) - @test checkbounds(Bool, sv, 2, 2) - @test !checkbounds(Bool, sv, 123, 123) - svnest = ShiftedArray(ShiftedArray(v, (1, 1)), 2) - sv = ShiftedArray(v, (3, 1)) - @test sv === svnest - sv = ShiftedArray(v, 2, default = nothing) - sv1 = ShiftedArray(sv, (1, 1)) - sv2 = ShiftedArray(sv, (1, 1), default = 0) - @test isequal(collect(sv1), [nothing nothing nothing nothing - nothing nothing nothing nothing - nothing nothing nothing nothing - nothing 1 5 9 ]) - @test isequal(collect(sv2), [0 0 0 0 - 0 nothing nothing nothing - 0 nothing nothing nothing - 0 1 5 9 ]) -end - -@testset "padded_tuple" begin - v = rand(2, 2) - @test (1, 0) == @inferred ShiftedArrays.padded_tuple(v, 1) - @test (0, 0) == @inferred ShiftedArrays.padded_tuple(v, ()) - @test (3, 0) == @inferred ShiftedArrays.padded_tuple(v, (3,)) - @test (1, 5) == @inferred ShiftedArrays.padded_tuple(v, (1, 5)) -end - -@testset "bringwithin" begin - @test ShiftedArrays.bringwithin(1, 1:10) == 1 - @test ShiftedArrays.bringwithin(0, 1:10) == 10 - @test ShiftedArrays.bringwithin(-1, 1:10) == 9 - - # test to check for offset axes - @test ShiftedArrays.bringwithin(5, 5:10) == 5 - @test ShiftedArrays.bringwithin(4, 5:10) == 10 -end - -@testset "CircShiftedVector" begin - v = [1, 3, 5, 4] - @test all(v .== CircShiftedVector(v)) - sv = CircShiftedVector(v, -1) - @test isequal(sv, CircShiftedVector(v, (-1,))) - @test length(sv) == 4 - @test all(sv .== [3, 5, 4, 1]) - diff = v .- sv - @test diff == [-2, -2, 1, 3] - @test shifts(sv) == (3,) - sv2 = CircShiftedVector(v, 1) - diff = v .- sv2 - @test copy(sv2) == [4, 1, 3, 5] - @test all(CircShiftedVector(v, 1) .== circshift(v, 1)) - sv[2] = 0 - @test collect(sv) == [3, 0, 4, 1] - @test v == [1, 3, 0, 4] - sv[3] = 12 - @test collect(sv) == [3, 0, 12, 1] - @test v == [1, 3, 0, 12] - @test sv === setindex!(sv, 12, 3) - @test checkbounds(Bool, sv, 2) - @test !checkbounds(Bool, sv, 123) - sv = CircShiftedArray(v, 3) - svnest = CircShiftedArray(CircShiftedArray(v, 2), 1) - @test sv === svnest -end +Random.seed!(42) -@testset "CircShiftedArray" begin - v = reshape(1:16, 4, 4) - @test all(v .== CircShiftedArray(v)) - sv = CircShiftedArray(v, (-2, 0)) - @test length(sv) == 16 - @test sv[1, 3] == 11 - @test shifts(sv) == (2, 0) - @test isequal(sv, CircShiftedArray(v, -2)) - @test isequal(@inferred(CircShiftedArray(v, 2)), @inferred(CircShiftedArray(v, (2,)))) - @test isequal(@inferred(CircShiftedArray(v)), @inferred(CircShiftedArray(v, (0, 0)))) - s = CircShiftedArray(v, (0, 2)) - @test isequal(collect(s), [ 9 13 1 5; - 10 14 2 6; - 11 15 3 7; - 12 16 4 8]) - sv = CircShiftedArray(v, 3) - svnest = CircShiftedArray(CircShiftedArray(v, 2), 1) - @test sv === svnest +function opt_cu(img, use_cuda) + if (use_cuda) + CuArray(img) + else + img + end end -@testset "circshift" begin - v = reshape(1:16, 4, 4) - @test all(circshift(v, (1, -1)) .== ShiftedArrays.circshift(v, (1, -1))) - @test all(circshift(v, (1,)) .== ShiftedArrays.circshift(v, (1,))) - @test all(circshift(v, 3) .== ShiftedArrays.circshift(v, 3)) - sv = ShiftedArrays.circshift(v, 3) - svnest = ShiftedArrays.circshift(ShiftedArrays.circshift(v, 2), 1) - @test sv === svnest +function run_all_tests(use_cuda=false) + @testset "ShiftedVector" begin + v = [1, 3, 5, 4] + v = opt_cu(v, use_cuda); + # missing in a UnionType is not allowed for element-wise comparison of CuArrays with all + @test all(v .== ShiftedVector(v, default=0)) + sv = ShiftedVector(v, -1) + @test isequal(sv, ShiftedVector(v, (-1,))) + @test length(sv) == 4 + @test ismissing(sv[4]) + diff = v .- sv + @test isequal(diff, opt_cu([-2, -2, 1, missing], use_cuda)) + # missing in a UnionType is not allowed for element-wise comparison of CuArrays with all + @test shifts(sv) == (-1,) + svneg = ShiftedVector(v, -1, default = -100) + @test default(svneg) == -100 + @test copy(svneg) == coalesce.(sv, -100) + @test isequal(sv[1:3], opt_cu(Union{Int64, Missing}[3, 5, 4], use_cuda)) + sv = ShiftedVector(v, -1, default=0) + @test all(sv[1:3] .== opt_cu([3, 5, 4], use_cuda)) + svnest = ShiftedVector(ShiftedVector(v, 1), 2) + sv = ShiftedVector(v, 3) + @test sv === svnest + sv = ShiftedVector(v, 2, default = nothing) + sv1 = ShiftedVector(sv, 1) + sv2 = ShiftedVector(sv, 1, default = 0) + @test isequal(collect(sv1), opt_cu([nothing, nothing, nothing, 1], use_cuda)) + @test isequal(collect(sv2), opt_cu([0, nothing, nothing, 1], use_cuda)) + end + + @testset "ShiftedArray" begin + v = reshape(1:16, 4, 4) + v = opt_cu(v, use_cuda); + @test all(v .== ShiftedArray(v, default=0)) + sv = ShiftedArray(v, (-2, 0)) + @test length(sv) == 16 + CUDA.@allowscalar @test sv[1, 3] == 11 + @test ismissing(sv[3, 3]) + @test shifts(sv) == (-2,0) + @test isequal(sv, ShiftedArray(v, -2)) + @test isequal(@inferred(ShiftedArray(v, (2,))), @inferred(ShiftedArray(v, 2))) + @test isequal(@inferred(ShiftedArray(v)), @inferred(ShiftedArray(v, (0, 0)))) + s = ShiftedArray(v, (0, -2)) + @test isequal(collect(s), opt_cu([ 9 13 missing missing; + 10 14 missing missing; + 11 15 missing missing; + 12 16 missing missing], use_cuda)) + sneg = ShiftedArray(v, (0, -2), default = -100) + @test all(sneg .== coalesce.(s, default(sneg))) + @test checkbounds(Bool, sv, 2, 2) + @test !checkbounds(Bool, sv, 123, 123) + svnest = ShiftedArray(ShiftedArray(v, (1, 1)), 2) + sv = ShiftedArray(v, (3, 1)) + @test sv === svnest + sv = ShiftedArray(v, 2, default = nothing) + sv1 = ShiftedArray(sv, (1, 1)) + sv2 = ShiftedArray(sv, (1, 1), default = 0) + @test isequal(collect(sv1), opt_cu([nothing nothing nothing nothing + nothing nothing nothing nothing + nothing nothing nothing nothing + nothing 1 5 9 ], use_cuda)) + @test isequal(collect(sv2), opt_cu([0 0 0 0 + 0 nothing nothing nothing + 0 nothing nothing nothing + 0 1 5 9 ], use_cuda)) + end + + @testset "padded_tuple" begin + v = rand(2, 2) + v = opt_cu(v, use_cuda); + @test (1, 0) == @inferred ShiftedArrays.padded_tuple(v, 1) + @test (0, 0) == @inferred ShiftedArrays.padded_tuple(v, ()) + @test (3, 0) == @inferred ShiftedArrays.padded_tuple(v, (3,)) + @test (1, 5) == @inferred ShiftedArrays.padded_tuple(v, (1, 5)) + end + + @testset "bringwithin" begin + @test ShiftedArrays.bringwithin(1, 1:10) == 1 + @test ShiftedArrays.bringwithin(0, 1:10) == 10 + @test ShiftedArrays.bringwithin(-1, 1:10) == 9 + + # test to check for offset axes + @test ShiftedArrays.bringwithin(5, 5:10) == 5 + @test ShiftedArrays.bringwithin(4, 5:10) == 10 + end + + @testset "CircShiftedVector" begin + v = [1, 3, 5, 4] + v = opt_cu(v, use_cuda); + @test all(v .== CircShiftedVector(v)) + sv = CircShiftedVector(v, -1) + @test isequal(sv, CircShiftedVector(v, (-1,))) + @test length(sv) == 4 + @test all(sv .== opt_cu([3, 5, 4, 1], use_cuda)) + diff = v .- sv + @test diff == opt_cu([-2, -2, 1, 3], use_cuda) + @test shifts(sv) == (3,) + sv2 = CircShiftedVector(v, 1) + diff = v .- sv2 + @test copy(sv2) == opt_cu([4, 1, 3, 5], use_cuda) + @test all(CircShiftedVector(v, 1) .== circshift(v, 1)) + CUDA.@allowscalar sv[2] = 0 + @test collect(sv) == opt_cu([3, 0, 4, 1], use_cuda) + @test v == opt_cu([1, 3, 0, 4], use_cuda) + CUDA.@allowscalar sv[3] = 12 + @test collect(sv) == opt_cu([3, 0, 12, 1], use_cuda) + @test v == opt_cu([1, 3, 0, 12], use_cuda) + CUDA.@allowscalar @test sv === setindex!(sv, 12, 3) + @test checkbounds(Bool, sv, 2) + @test !checkbounds(Bool, sv, 123) + sv = CircShiftedArray(v, 3) + svnest = CircShiftedArray(CircShiftedArray(v, 2), 1) + @test sv === svnest + end + + @testset "CircShiftedArray" begin + v = reshape(1:16, 4, 4) + v = opt_cu(v, use_cuda); + @test all(v .== CircShiftedArray(v)) + sv = CircShiftedArray(v, (-2, 0)) + @test length(sv) == 16 + CUDA.@allowscalar @test sv[1, 3] == 11 + @test shifts(sv) == (2, 0) + @test isequal(sv, CircShiftedArray(v, -2)) + @test isequal(@inferred(CircShiftedArray(v, 2)), @inferred(CircShiftedArray(v, (2,)))) + @test isequal(@inferred(CircShiftedArray(v)), @inferred(CircShiftedArray(v, (0, 0)))) + s = CircShiftedArray(v, (0, 2)) + @test isequal(collect(s), opt_cu([ 9 13 1 5; + 10 14 2 6; + 11 15 3 7; + 12 16 4 8], use_cuda)) + sv = CircShiftedArray(v, 3) + svnest = CircShiftedArray(CircShiftedArray(v, 2), 1) + @test sv === svnest + end + + @testset "circshift" begin + v = reshape(1:16, 4, 4) + v = opt_cu(v, use_cuda); + @test all(circshift(v, (1, -1)) .== ShiftedArrays.circshift(v, (1, -1))) + @test all(circshift(v, (1,)) .== ShiftedArrays.circshift(v, (1,))) + @test all(circshift(v, 3) .== ShiftedArrays.circshift(v, 3)) + sv = ShiftedArrays.circshift(v, 3) + svnest = ShiftedArrays.circshift(ShiftedArrays.circshift(v, 2), 1) + @test sv === svnest + end + + @testset "fftshift and ifftshift" begin + function test_fftshift(x, dims=1:ndims(x)) + x = opt_cu(x, use_cuda) + @test fftshift(x, dims) == ShiftedArrays.fftshift(x, dims) + @test ifftshift(x, dims) == ShiftedArrays.ifftshift(x, dims) + end + + test_fftshift(randn((10,))) + test_fftshift(randn((11,))) + test_fftshift(randn((10,)), (1,)) + test_fftshift(randn(ComplexF32, (11,)), (1,)) + test_fftshift(randn((10, 11)), (1,)) + test_fftshift(randn((10, 11)), (2,)) + test_fftshift(randn(ComplexF32,(10, 11)), (1, 2)) + test_fftshift(randn((10, 11))) + + test_fftshift(randn((10, 11, 12, 13)), (2, 4)) + test_fftshift(randn((10, 11, 12, 13)), (5)) + test_fftshift(randn((10, 11, 12, 13))) + + @test (2, 2, 0) == ShiftedArrays.ft_center_diff((4, 5, 6), (1, 2)) # Fourier center is at (2, 3, 0) + @test (2, 2, 3) == ShiftedArrays.ft_center_diff((4, 5, 6), (1, 2, 3)) # Fourier center is at (2, 3, 4) + end + + @testset "laglead" begin + v = [1, 3, 8, 12] + v = opt_cu(v, use_cuda); + diff = v .- ShiftedArrays.lag(v) + @test isequal(diff, opt_cu([missing, 2, 5, 4], use_cuda)) + + diff2 = v .- ShiftedArrays.lag(v, 2) + @test isequal(diff2, opt_cu([missing, missing, 7, 9], use_cuda)) + + @test all(ShiftedArrays.lag(v, 2, default = -100) .== coalesce.(ShiftedArrays.lag(v, 2), -100)) + + diff = v .- ShiftedArrays.lead(v) + @test isequal(diff, opt_cu([-2, -5, -4, missing], use_cuda)) + + diff2 = v .- ShiftedArrays.lead(v, 2) + @test isequal(diff2, opt_cu([-7, -9, missing, missing], use_cuda)) + + @test all(ShiftedArrays.lead(v, 2, default = -100) .== coalesce.(ShiftedArrays.lead(v, 2), -100)) + + @test ShiftedArrays.lag(ShiftedArrays.lag(v, 1), 2) === ShiftedArrays.lag(v, 3) + @test ShiftedArrays.lead(ShiftedArrays.lead(v, 1), 2) === ShiftedArrays.lead(v, 3) + end end -@testset "fftshift and ifftshift" begin - function test_fftshift(x, dims=1:ndims(x)) - @test fftshift(x, dims) == ShiftedArrays.fftshift(x, dims) - @test ifftshift(x, dims) == ShiftedArrays.ifftshift(x, dims) +run_all_tests() + +if CUDA.functional() + @testset "all in CUDA" begin + CUDA.allowscalar(false); + run_all_tests(true) + + # some extra tests to check for indexing with integers + v = rand(10,11) + sv = ShiftedArray(cu(rand(10,11)), (3,4)) + @test_throws ErrorException sv[5,6] + @test (CUDA.@allowscalar sv[5,6]) == Array(sv)[5,6] + @test_throws ErrorException sv[47] + @test_throws BoundsError sv[1,2,3] + @test (CUDA.@allowscalar sv[47]) == Array(sv)[47] + cv = CircShiftedArray(cu(rand(10,11)), (3,4)) + @test_throws ErrorException cv[5,6] + @test (CUDA.@allowscalar cv[5,6]) == Array(cv)[5,6] + @test_throws ErrorException cv[47] + @test_throws BoundsError cv[1,2,3] + @test (CUDA.@allowscalar cv[47]) == Array(cv)[47] + end +else + @testset "no CUDA available!" begin + @test true == true end - - test_fftshift(randn((10,))) - test_fftshift(randn((11,))) - test_fftshift(randn((10,)), (1,)) - test_fftshift(randn(ComplexF32, (11,)), (1,)) - test_fftshift(randn((10, 11)), (1,)) - test_fftshift(randn((10, 11)), (2,)) - test_fftshift(randn(ComplexF32,(10, 11)), (1, 2)) - test_fftshift(randn((10, 11))) - - test_fftshift(randn((10, 11, 12, 13)), (2, 4)) - test_fftshift(randn((10, 11, 12, 13)), (5)) - test_fftshift(randn((10, 11, 12, 13))) - - @test (2, 2, 0) == ShiftedArrays.ft_center_diff((4, 5, 6), (1, 2)) # Fourier center is at (2, 3, 0) - @test (2, 2, 3) == ShiftedArrays.ft_center_diff((4, 5, 6), (1, 2, 3)) # Fourier center is at (2, 3, 4) end -@testset "laglead" begin - v = [1, 3, 8, 12] - diff = v .- ShiftedArrays.lag(v) - @test isequal(diff, [missing, 2, 5, 4]) - - diff2 = v .- ShiftedArrays.lag(v, 2) - @test isequal(diff2, [missing, missing, 7, 9]) - - @test all(ShiftedArrays.lag(v, 2, default = -100) .== coalesce.(ShiftedArrays.lag(v, 2), -100)) - - diff = v .- ShiftedArrays.lead(v) - @test isequal(diff, [-2, -5, -4, missing]) - - diff2 = v .- ShiftedArrays.lead(v, 2) - @test isequal(diff2, [-7, -9, missing, missing]) - - @test all(ShiftedArrays.lead(v, 2, default = -100) .== coalesce.(ShiftedArrays.lead(v, 2), -100)) - - @test ShiftedArrays.lag(ShiftedArrays.lag(v, 1), 2) === ShiftedArrays.lag(v, 3) - @test ShiftedArrays.lead(ShiftedArrays.lead(v, 1), 2) === ShiftedArrays.lead(v, 3) -end