diff --git a/src/memopt.jl b/src/memopt.jl index ccd3303..dacc336 100644 --- a/src/memopt.jl +++ b/src/memopt.jl @@ -90,6 +90,7 @@ function memopt(metadata_module::Module, is_parallel_kernel::Bool, caller::Modul loopstart = minimum(values(loopentrys)) loopend = loopsize use_any_shmem = any(values(use_shmems)) + shmem_optvars = tuple((A for A in optvars if use_shmems[A])...)::Tuple{Vararg{Symbol}} shmem_index_groups = define_shmem_index_groups(hx1s, hy1s, hx2s, hy2s, optvars, use_shmems, loopdim) shmem_vars = define_shmem_vars(oz_maxs, hx1s, hy1s, hx2s, hy2s, optvars, indices, use_shmems, use_shmem_xs, use_shmem_ys, shmem_index_groups, use_shmemhalos, use_shmemindices, loopdim) shmem_exprs = define_shmem_exprs(shmem_vars, loopdim) @@ -469,7 +470,7 @@ $(( # NOTE: the if statement is not needed here as we only deal with registers else @ArgumentError("memopt: only loopdim=3 is currently supported.") end - store_metadata(metadata_module, is_parallel_kernel, caller, offset_mins, offset_maxs, offsets, optvars, loopdim, loopsize, optranges, use_shmemhalos) + store_metadata(metadata_module, is_parallel_kernel, caller, offset_mins, offset_maxs, offsets, optvars, shmem_optvars, use_any_shmem, loopdim, loopsize, optranges, use_shmemhalos) # @show QuoteNode(ParallelKernel.simplify_varnames!(ParallelKernel.remove_linenumbernodes!(deepcopy(body)))) return body end @@ -1019,10 +1020,15 @@ function wrap_loop(index::Symbol, range::UnitRange, block::Expr; unroll=false) end end -function store_metadata(metadata_module::Module, is_parallel_kernel::Bool, caller::Module, offset_mins::Dict{Symbol, <:NTuple{3,Integer}}, offset_maxs::Dict{Symbol, <:NTuple{3,Integer}}, offsets::Dict{Symbol, Dict{Any, Any}}, optvars::NTuple{N,Symbol} where N, loopdim::Integer, loopsize::Integer, optranges::Dict{Any, Any}, use_shmemhalos) +function store_metadata(metadata_module::Module, is_parallel_kernel::Bool, caller::Module, offset_mins::Dict{Symbol, <:NTuple{3,Integer}}, offset_maxs::Dict{Symbol, <:NTuple{3,Integer}}, offsets::Dict{Symbol, Dict{Any, Any}}, optvars::NTuple{N,Symbol} where N, shmem_optvars::NTuple{M,Symbol} where M, use_any_shmem::Bool, loopdim::Integer, loopsize::Integer, optranges::Dict{Any, Any}, use_shmemhalos) memopt = true nonconst_metadata = get_nonconst_metadata(caller) stencilranges = NamedTuple(A => (offset_mins[A][1]:offset_maxs[A][1], offset_mins[A][2]:offset_maxs[A][2], offset_mins[A][3]:offset_maxs[A][3]) for A in optvars) + use_shmemhalos = NamedTuple(A => use_shmemhalos[A] for A in optvars) + loopsizes = (loopdim==3) ? (1, 1, loopsize) : (loopdim==2) ? (1, loopsize, 1) : (loopsize, 1, 1) + shmem_dim1 = (loopdim==3) ? 1 : (loopdim==2) ? 1 : 2 + shmem_dim2 = (loopdim==3) ? 2 : (loopdim==2) ? 3 : 3 + shmem_spans = NamedTuple(A => (length(stencilranges[A][shmem_dim1]) - 1, length(stencilranges[A][shmem_dim2]) - 1) for A in optvars) if nonconst_metadata storeexpr = quote is_parallel_kernel = $is_parallel_kernel @@ -1031,9 +1037,15 @@ function store_metadata(metadata_module::Module, is_parallel_kernel::Bool, calle stencilranges = $stencilranges offsets = $offsets optvars = $optvars + shmem_optvars = $shmem_optvars + shmem_spans = $shmem_spans loopdim = $loopdim loopsize = $loopsize + loopsizes = $loopsizes + shmem_dim1 = $shmem_dim1 + shmem_dim2 = $shmem_dim2 optranges = $optranges + use_any_shmem = $use_any_shmem use_shmemhalos = $use_shmemhalos end else @@ -1044,9 +1056,15 @@ function store_metadata(metadata_module::Module, is_parallel_kernel::Bool, calle const stencilranges = $stencilranges const offsets = $offsets const optvars = $optvars + const shmem_optvars = $shmem_optvars + const shmem_spans = $shmem_spans const loopdim = $loopdim const loopsize = $loopsize + const loopsizes = $loopsizes + const shmem_dim1 = $shmem_dim1 + const shmem_dim2 = $shmem_dim2 const optranges = $optranges + const use_any_shmem = $use_any_shmem const use_shmemhalos = $use_shmemhalos end end diff --git a/src/parallel.jl b/src/parallel.jl index e07e211..c8397dd 100644 --- a/src/parallel.jl +++ b/src/parallel.jl @@ -361,6 +361,38 @@ end ## @PARALLEL CALL FUNCTIONS +@generated function compute_memopt_shmem(::Val{optvars}, ::Val{use_shmemhalos}, ::Val{shmem_spans}, ::Val{shmem_dim1}, ::Val{shmem_dim2}, nthreads, ::Type{T}) where {optvars, use_shmemhalos, shmem_spans, shmem_dim1, shmem_dim2, T} + terms = [:( + (nthreads[$shmem_dim1] + $(getproperty(use_shmemhalos, A)) * $(getproperty(shmem_spans, A)[1])) * + (nthreads[$shmem_dim2] + $(getproperty(use_shmemhalos, A)) * $(getproperty(shmem_spans, A)[2])) * + sizeof(T) + ) for A in optvars] + if isempty(terms) + return :(0) + elseif length(terms) == 1 + return terms[1] + else + return Expr(:call, :+, terms...) + end +end + +@generated function compute_memopt_ranges(::Val{is_parallel_kernel}, ::Val{nb_parallel_indices}, ::Val{loopdim}, nthreads_x_max, nthreads_max_memopt, args...) where {is_parallel_kernel, nb_parallel_indices, loopdim} + if is_parallel_kernel + range_expr = :(ParallelStencil.get_ranges_memopt(nthreads_x_max, nthreads_max_memopt, $loopdim, args...)) + else + range_expr = :(ParallelStencil.ParallelKernel.get_ranges(args...)) + end + errorcall = :(ParallelStencil.@ArgumentError(ParallelStencil.ERRMSG_AUTOMATIC_RANGES_PARALLEL)) + return quote + nb_input_dims = ParallelStencil.get_nb_input_dims(args...) + nb_dims_match = (nb_input_dims == $nb_parallel_indices) + if nb_dims_match isa Bool + nb_dims_match || $errorcall + end + $range_expr + end +end + function parallel_call_memopt(caller::Module, ranges::Union{Symbol,Expr}, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool; memopt::Bool=false, configcall::Expr=kernelcall) if haskey(backend_kwargs_expr, :shmem) @KeywordArgumentError("@parallel : keyword `shmem` is not allowed when memopt=true is set.") end package = get_package(caller) @@ -369,20 +401,25 @@ function parallel_call_memopt(caller::Module, ranges::Union{Symbol,Expr}, kernel configcall_kwarg_expr = :(configcall=$configcall) metadata_call = create_metadata_call(configcall) metadata_module = metadata_call + loopdim = :($(metadata_module).loopdim) + loopsizes = :($(metadata_module).loopsizes) stencilranges = :($(metadata_module).stencilranges) use_shmemhalos = :($(metadata_module).use_shmemhalos) - optvars = :($(metadata_module).optvars) - loopdim = :($(metadata_module).loopdim) - loopsize = :($(metadata_module).loopsize) - loopsizes = :(($loopdim==3) ? (1, 1, $loopsize) : ($loopdim==2) ? (1, $loopsize, 1) : ($loopsize, 1, 1)) + use_any_shmem = :($(metadata_module).use_any_shmem) + shmem_dim1 = :($(metadata_module).shmem_dim1) + shmem_dim2 = :($(metadata_module).shmem_dim2) + shmem_optvars = :($(metadata_module).shmem_optvars) + shmem_spans = :($(metadata_module).shmem_spans) maxsize = :(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges($ranges)), $loopsizes)) nthreads = :( ParallelStencil.compute_nthreads_memopt($nthreads_x_max, $nthreads_max_memopt, $maxsize, $loopdim, $stencilranges) ) nblocks = :( ParallelStencil.ParallelKernel.compute_nblocks($maxsize, $nthreads) ) numbertype = get_numbertype(caller) # not :(eltype($(optvars)[1])) # TODO: see how to obtain number type properly for each array: the type of the call call arguments corresponding to the optimization variables should be checked - dim1 = :(($loopdim==3) ? 1 : ($loopdim==2) ? 1 : 2) # TODO: to be determined if that is what is desired for loopdim 1 and 2. - dim2 = :(($loopdim==3) ? 2 : ($loopdim==2) ? 3 : 3) # TODO: to be determined if that is what is desired for loopdim 1 and 2. - A = gensym("A") - shmem = :(sum(($nthreads[$dim1]+$use_shmemhalos[$A]*(length($(stencilranges)[$A][$dim1])-1))*($nthreads[$dim2]+$use_shmemhalos[$A]*(length($(stencilranges)[$A][$dim2])-1))*sizeof($numbertype) for $A in $optvars)) + if get_nonconst_metadata(caller) + A = gensym("A") + shmem = :($use_any_shmem ? sum(($nthreads[$shmem_dim1] + $use_shmemhalos[$A] * ($(shmem_spans)[$A][1])) * ($nthreads[$shmem_dim2] + $use_shmemhalos[$A] * ($(shmem_spans)[$A][2])) * sizeof($numbertype) for $A in $shmem_optvars) : 0) + else + shmem = :(ParallelStencil.compute_memopt_shmem(Val($shmem_optvars), Val($use_shmemhalos), Val($shmem_spans), Val($shmem_dim1), Val($shmem_dim2), $nthreads, $numbertype)) + end if (async) return :(@parallel_async memopt=false $configcall_kwarg_expr $ranges $nblocks $nthreads shmem=$shmem $(backend_kwargs_expr...) $kernelcall) #TODO: the package and numbertype will have to be passed here further once supported as kwargs else return :(@parallel memopt=false $configcall_kwarg_expr $ranges $nblocks $nthreads shmem=$shmem $(backend_kwargs_expr...) $kernelcall) #TODO: ... end @@ -396,7 +433,12 @@ function parallel_call_memopt(caller::Module, kernelcall::Expr, backend_kwargs_e metadata_module = metadata_call loopdim = :($(metadata_module).loopdim) is_parallel_kernel = :($(metadata_module).is_parallel_kernel) - ranges = add_nb_parallel_indices_check(:( ($is_parallel_kernel) ? ParallelStencil.get_ranges_memopt($nthreads_x_max, $nthreads_max_memopt, $loopdim, $(configcall.args[2:end]...)) : ParallelStencil.ParallelKernel.get_ranges($(configcall.args[2:end]...))), configcall) + if get_nonconst_metadata(caller) + ranges = add_nb_parallel_indices_check(:( ($is_parallel_kernel) ? ParallelStencil.get_ranges_memopt($nthreads_x_max, $nthreads_max_memopt, $loopdim, $(configcall.args[2:end]...)) : ParallelStencil.ParallelKernel.get_ranges($(configcall.args[2:end]...))), configcall) + else + nb_parallel_indices = :($(metadata_module).nb_parallel_indices) + ranges = :(ParallelStencil.compute_memopt_ranges(Val($is_parallel_kernel), Val($nb_parallel_indices), Val($loopdim), $nthreads_x_max, $nthreads_max_memopt, $(configcall.args[2:end]...))) + end parallel_call_memopt(caller, ranges, kernelcall, backend_kwargs_expr, async; memopt=memopt, configcall=configcall) end @@ -552,8 +594,10 @@ function create_metadata_function(kernel::Expr, metadata_module::Module) # NOTE: kernelname = get_name(kernel) functionname = get_meta_function(kernelname) metadata_function = set_name(metadata_function, functionname) - set_body!(metadata_function, :(return $metadata_module)) - return metadata_function + set_body!(metadata_function, quote + return $metadata_module + end) + return :(@inline $metadata_function) end function create_metadata_call(configcall::Expr) diff --git a/test/test_FiniteDifferences1D.jl b/test/test_FiniteDifferences1D.jl index 4e1cd5a..2c95dca 100644 --- a/test/test_FiniteDifferences1D.jl +++ b/test/test_FiniteDifferences1D.jl @@ -36,7 +36,7 @@ eval(:( $(interpolate(:__padding__, (false, package!=PKG_POLYESTER), :( #TODO: this needs to be restored to (false, true) when Polyester supports padding. @testset "(padding=$__padding__)" begin @require !@is_initialized() - @init_parallel_stencil($package, $FloatDefault, 1, padding=__padding__) + @init_parallel_stencil($package, $FloatDefault, 1, padding=__padding__, nonconst_metadata=true) @require @is_initialized() nx = (9,) A = @IField(nx, @rand); diff --git a/test/test_FiniteDifferences2D.jl b/test/test_FiniteDifferences2D.jl index 17cf848..6b0a6c5 100644 --- a/test/test_FiniteDifferences2D.jl +++ b/test/test_FiniteDifferences2D.jl @@ -36,7 +36,7 @@ eval(:( $(interpolate(:__padding__, (false, package!=PKG_POLYESTER), :( #TODO: this needs to be restored to (false, true) when Polyester supports padding. @testset "(padding=$__padding__)" begin @require !@is_initialized() - @init_parallel_stencil($package, $FloatDefault, 2, padding=__padding__) + @init_parallel_stencil($package, $FloatDefault, 2, padding=__padding__, nonconst_metadata=true) @require @is_initialized() nxy = (9, 7) A = @IField(nxy, @rand); diff --git a/test/test_FiniteDifferences3D.jl b/test/test_FiniteDifferences3D.jl index eb76dc0..88013cf 100644 --- a/test/test_FiniteDifferences3D.jl +++ b/test/test_FiniteDifferences3D.jl @@ -36,7 +36,7 @@ eval(:( $(interpolate(:__padding__, (false, package!=PKG_POLYESTER), :( #TODO: this needs to be restored to (false, true) when Polyester supports padding. @testset "(padding=$__padding__)" begin @require !@is_initialized() - @init_parallel_stencil($package, $FloatDefault, 3, padding=__padding__) + @init_parallel_stencil($package, $FloatDefault, 3, padding=__padding__, nonconst_metadata=true) @require @is_initialized() nxyz = (9, 7, 8) A = @IField(nxyz, @rand) diff --git a/test/test_kernel_language.jl b/test/test_kernel_language.jl index 3bae9a9..72e0701 100644 --- a/test/test_kernel_language.jl +++ b/test/test_kernel_language.jl @@ -38,7 +38,7 @@ Base.retry_load_extensions() eval(:( @testset "$(basename(@__FILE__)) (package: $(nameof($package)))" begin @require !@is_initialized() - @init_parallel_stencil($package, $FloatDefault, 3) + @init_parallel_stencil($package, $FloatDefault, 3, nonconst_metadata=true) @require @is_initialized() @testset "Pass-through macro mapping" begin diff --git a/test/test_parallel.jl b/test/test_parallel.jl index eb231c0..f2b278a 100644 --- a/test/test_parallel.jl +++ b/test/test_parallel.jl @@ -1011,7 +1011,7 @@ eval(:( @testset "3. parallel (with Fields)" begin @static if $package != $PKG_POLYESTER # TODO: this needs to be removed once Polyester supports padding @require !@is_initialized() - @init_parallel_stencil($package, $FloatDefault, 3, padding=true) + @init_parallel_stencil($package, $FloatDefault, 3, padding=true, nonconst_metadata=true) @require @is_initialized() @testset "padding" begin @testset "@parallel (3D, @all)" begin @@ -1063,7 +1063,7 @@ eval(:( @testset "4. global defaults" begin @testset "inbounds=true" begin @require !@is_initialized() - @init_parallel_stencil($package, $FloatDefault, 1, inbounds=true) + @init_parallel_stencil($package, $FloatDefault, 1, inbounds=true, nonconst_metadata=true) @require @is_initialized expansion = @prettystring(1, @parallel_indices (ix) inbounds=true f(A) = (2*A; return)) @test occursin("Base.@inbounds begin", expansion) @@ -1076,7 +1076,7 @@ eval(:( @testset "padding=true" begin @static if $package != $PKG_POLYESTER # TODO: this needs to be removed once Polyester supports padding @require !@is_initialized() - @init_parallel_stencil($package, $FloatDefault, 3, padding=true) + @init_parallel_stencil($package, $FloatDefault, 3, padding=true, nonconst_metadata=true) @require @is_initialized @testset "apply masks | handling padding (padding=true (globally))" begin expansion = @prettystring(1, @parallel sum!(A, B) = (@all(A) = @all(A) + @all(B); return)) @@ -1097,7 +1097,7 @@ eval(:( end; @testset "@parallel_indices (I...) (1D)" begin @require !@is_initialized() - @init_parallel_stencil($package, $FloatDefault, 1) + @init_parallel_stencil($package, $FloatDefault, 1, nonconst_metadata=true) @require @is_initialized A = @zeros(4*5*6) one = $FloatDefault(1) @@ -1111,7 +1111,7 @@ eval(:( end; @testset "@parallel_indices (I...) (2D)" begin @require !@is_initialized() - @init_parallel_stencil($package, $FloatDefault, 2) + @init_parallel_stencil($package, $FloatDefault, 2, nonconst_metadata=true) @require @is_initialized A = @zeros(4, 5*6) one = $FloatDefault(1) @@ -1125,7 +1125,7 @@ eval(:( end; @testset "@parallel_indices (I...) (3D)" begin @require !@is_initialized() - @init_parallel_stencil($package, $FloatDefault, 3) + @init_parallel_stencil($package, $FloatDefault, 3, nonconst_metadata=true) @require @is_initialized A = @zeros(4, 5, 6) one = $FloatDefault(1) @@ -1140,7 +1140,7 @@ eval(:( end; @testset "5. parallel macros (numbertype and ndims ommited)" begin @require !@is_initialized() - @init_parallel_stencil(package = $package) + @init_parallel_stencil(package = $package, nonconst_metadata=true) @require @is_initialized $(interpolate(:__T__, ARRAYTYPES, :( @testset "Data.__T__{T} to Data.Device.__T__{T}" begin @@ -1263,18 +1263,53 @@ eval(:( metadata_symbols = sort(setdiff(names(metadata; all=true), names(metadata))) @test metadata isa Module @test length(names(metadata)) == 1 - @test metadata_symbols == [:is_parallel_kernel, :loopdim, :loopsize, :memopt, :nb_parallel_indices, :nonconst_metadata, :offsets, :optranges, :optvars, :stencilranges, :use_shmemhalos] + @test metadata_symbols == [:is_parallel_kernel, :loopdim, :loopsize, :loopsizes, :memopt, :nb_parallel_indices, :nonconst_metadata, :offsets, :optranges, :optvars, :shmem_dim1, :shmem_dim2, :shmem_optvars, :shmem_spans, :stencilranges, :use_any_shmem, :use_shmemhalos] @test metadata.is_parallel_kernel == false @test metadata.loopdim == 3 @test metadata.loopsize == 3 + @test metadata.loopsizes == (1, 1, 3) @test metadata.memopt == true @test metadata.nb_parallel_indices == 3 @test metadata.nonconst_metadata == true @test metadata.offsets[:B][(0, 0)][0] == 1 @test metadata.optranges[:B] == (0:0, 0:0, 0:0) @test metadata.optvars == (:B,) + @test metadata.shmem_dim1 == 1 + @test metadata.shmem_dim2 == 2 + @test metadata.shmem_optvars == () + @test metadata.shmem_optvars isa NTuple{0,Symbol} + @test metadata.shmem_spans == (B = (0, 0),) @test metadata.stencilranges == (B = (0:0, 0:0, 0:0),) + @test metadata.use_any_shmem == false @test metadata.use_shmemhalos[:B] == true + @parallel_indices (ix, iy, iz) memopt=true loopsize=3 optvars=B optranges=(B=(0:0,0:0,-1:1),) function metadata_memopt_zstencil_probe!(A, B, D) + A[ix, iy, iz] = B[ix, iy, iz-1] + B[ix, iy, iz] + B[ix, iy, iz+1] + D[ix, iy, iz, 1] + return + end + metadata_z = @metadata metadata_memopt_zstencil_probe!(A, B, D) + @test metadata_z.shmem_optvars == () + @test metadata_z.shmem_optvars isa NTuple{0,Symbol} + @test metadata_z.shmem_spans == (B = (0, 0),) + @test metadata_z.stencilranges == (B = (0:0, 0:0, -1:1),) + @test metadata_z.use_any_shmem == false + @parallel_indices (ix, iy, iz) memopt=true loopsize=3 optvars=B function metadata_memopt_fullstencil_probe!(A, B, D) + A[ix, iy, iz] = B[ix-1, iy-1, iz-1] + B[ix, iy, iz] + B[ix+1, iy+1, iz+1] + D[ix, iy, iz, 1] + return + end + metadata_full = @metadata metadata_memopt_fullstencil_probe!(A, B, D) + @static if @isgpu($package) + @test metadata_full.shmem_optvars == (:B,) + @test metadata_full.shmem_optvars isa NTuple{1,Symbol} + @test metadata_full.shmem_spans == (B = (2, 2),) + @test metadata_full.stencilranges == (B = (-1:1, -1:1, -1:1),) + @test metadata_full.use_any_shmem == true + else + @test metadata_full.shmem_optvars == () + @test metadata_full.shmem_optvars isa NTuple{0,Symbol} + @test metadata_full.shmem_spans == (B = (0, 0),) + @test metadata_full.stencilranges == (B = (0:0, 0:0, 0:0),) + @test metadata_full.use_any_shmem == false + end @test all(Array(A) .== 0) end; end; @@ -1360,10 +1395,4 @@ eval(:( end; )) -eval(:( - @testset "$(basename(@__FILE__)) metadata (package: $(nameof($package)))" begin - - end; -)) - end == nothing || true;