From 60364adae4d2f6964360bc5101a2386314c88f47 Mon Sep 17 00:00:00 2001 From: Samuel Omlin Date: Mon, 16 Mar 2026 17:34:11 +0100 Subject: [PATCH 1/7] Add functions to compute shared memory size and ranges for memopt --- src/parallel.jl | 67 ++++++++++++++++++++++++++----------------------- 1 file changed, 35 insertions(+), 32 deletions(-) diff --git a/src/parallel.jl b/src/parallel.jl index c8397dd..911ced3 100644 --- a/src/parallel.jl +++ b/src/parallel.jl @@ -361,38 +361,6 @@ end ## @PARALLEL CALL FUNCTIONS -@generated function compute_memopt_shmem(::Val{optvars}, ::Val{use_shmemhalos}, ::Val{shmem_spans}, ::Val{shmem_dim1}, ::Val{shmem_dim2}, nthreads, ::Type{T}) where {optvars, use_shmemhalos, shmem_spans, shmem_dim1, shmem_dim2, T} - terms = [:( - (nthreads[$shmem_dim1] + $(getproperty(use_shmemhalos, A)) * $(getproperty(shmem_spans, A)[1])) * - (nthreads[$shmem_dim2] + $(getproperty(use_shmemhalos, A)) * $(getproperty(shmem_spans, A)[2])) * - sizeof(T) - ) for A in optvars] - if isempty(terms) - return :(0) - elseif length(terms) == 1 - return terms[1] - else - return Expr(:call, :+, terms...) - end -end - -@generated function compute_memopt_ranges(::Val{is_parallel_kernel}, ::Val{nb_parallel_indices}, ::Val{loopdim}, nthreads_x_max, nthreads_max_memopt, args...) where {is_parallel_kernel, nb_parallel_indices, loopdim} - if is_parallel_kernel - range_expr = :(ParallelStencil.get_ranges_memopt(nthreads_x_max, nthreads_max_memopt, $loopdim, args...)) - else - range_expr = :(ParallelStencil.ParallelKernel.get_ranges(args...)) - end - errorcall = :(ParallelStencil.@ArgumentError(ParallelStencil.ERRMSG_AUTOMATIC_RANGES_PARALLEL)) - return quote - nb_input_dims = ParallelStencil.get_nb_input_dims(args...) - nb_dims_match = (nb_input_dims == $nb_parallel_indices) - if nb_dims_match isa Bool - nb_dims_match || $errorcall - end - $range_expr - end -end - function parallel_call_memopt(caller::Module, ranges::Union{Symbol,Expr}, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool; memopt::Bool=false, configcall::Expr=kernelcall) if haskey(backend_kwargs_expr, :shmem) @KeywordArgumentError("@parallel : keyword `shmem` is not allowed when memopt=true is set.") end package = get_package(caller) @@ -497,6 +465,41 @@ function get_ranges_memopt(nthreads_x_max, nthreads_max_memopt, loopdim, args... end +## FUNCTIONS TO COMPUTE SHARED MEMORY SIZE AND RANGES FOR MEMOPT + +@generated function compute_memopt_shmem(::Val{optvars}, ::Val{use_shmemhalos}, ::Val{shmem_spans}, ::Val{shmem_dim1}, ::Val{shmem_dim2}, nthreads, ::Type{T}) where {optvars, use_shmemhalos, shmem_spans, shmem_dim1, shmem_dim2, T} + terms = [:( + (nthreads[$shmem_dim1] + $(getproperty(use_shmemhalos, A)) * $(getproperty(shmem_spans, A)[1])) * + (nthreads[$shmem_dim2] + $(getproperty(use_shmemhalos, A)) * $(getproperty(shmem_spans, A)[2])) * + sizeof(T) + ) for A in optvars] + if isempty(terms) + return :(0) + elseif length(terms) == 1 + return terms[1] + else + return Expr(:call, :+, terms...) + end +end + +@generated function compute_memopt_ranges(::Val{is_parallel_kernel}, ::Val{nb_parallel_indices}, ::Val{loopdim}, nthreads_x_max, nthreads_max_memopt, args...) where {is_parallel_kernel, nb_parallel_indices, loopdim} + if is_parallel_kernel + range_expr = :(ParallelStencil.get_ranges_memopt(nthreads_x_max, nthreads_max_memopt, $loopdim, args...)) + else + range_expr = :(ParallelStencil.ParallelKernel.get_ranges(args...)) + end + errorcall = :(ParallelStencil.@ArgumentError(ParallelStencil.ERRMSG_AUTOMATIC_RANGES_PARALLEL)) + return quote + nb_input_dims = ParallelStencil.get_nb_input_dims(args...) + nb_dims_match = (nb_input_dims == $nb_parallel_indices) + if nb_dims_match isa Bool + nb_dims_match || $errorcall + end + $range_expr + end +end + + ## FUNCTIONS TO DEAL WITH MASKS (@WITHIN) AND INDICES is_splatarg(x) = isa(x,Expr) && (x.head == :...) From 299f7054e84c7cc37b8c198046629ce50d21941d Mon Sep 17 00:00:00 2001 From: Samuel Omlin Date: Tue, 17 Mar 2026 09:03:53 +0100 Subject: [PATCH 2/7] Refactor parallel_call_memopt to utilize metadata for memory optimization and range computation --- src/parallel.jl | 101 +++++++++++++++++++++++++++++++----------------- 1 file changed, 66 insertions(+), 35 deletions(-) diff --git a/src/parallel.jl b/src/parallel.jl index 911ced3..f384f11 100644 --- a/src/parallel.jl +++ b/src/parallel.jl @@ -191,7 +191,7 @@ function parallel(source::LineNumberNode, caller::Module, args::Union{Symbol,Exp parallel_call_memopt(caller, posargs..., kernelarg, backend_kwargs_expr, async; kwargs...) else if isempty(posargs) - ranges = add_nb_parallel_indices_check(:(ParallelStencil.ParallelKernel.get_ranges($(configcall.args[2:end]...))), configcall) + ranges = :(ParallelStencil.compute_parallel_ranges(Val(($(create_metadata_call(configcall))).nb_parallel_indices), $(configcall.args[2:end]...))) ParallelKernel.parallel(caller, ranges, backend_kwargs_expr..., configcall_kwarg_expr, kernelarg; package=package, async=async) else ParallelKernel.parallel(caller, posargs..., backend_kwargs_expr..., configcall_kwarg_expr, kernelarg; package=package, async=async) @@ -361,35 +361,44 @@ end ## @PARALLEL CALL FUNCTIONS -function parallel_call_memopt(caller::Module, ranges::Union{Symbol,Expr}, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool; memopt::Bool=false, configcall::Expr=kernelcall) +function parallel_call_memopt(caller::Module, metadata_expr::Union{Symbol,Expr}, ranges::Union{Symbol,Expr}, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool; memopt::Bool=false, configcall::Expr=kernelcall) if haskey(backend_kwargs_expr, :shmem) @KeywordArgumentError("@parallel : keyword `shmem` is not allowed when memopt=true is set.") end package = get_package(caller) nthreads_x_max = ParallelKernel.determine_nthreads_x_max(package) nthreads_max_memopt = determine_nthreads_max_memopt(package) configcall_kwarg_expr = :(configcall=$configcall) - metadata_call = create_metadata_call(configcall) - metadata_module = metadata_call - loopdim = :($(metadata_module).loopdim) - loopsizes = :($(metadata_module).loopsizes) - stencilranges = :($(metadata_module).stencilranges) - use_shmemhalos = :($(metadata_module).use_shmemhalos) - use_any_shmem = :($(metadata_module).use_any_shmem) - shmem_dim1 = :($(metadata_module).shmem_dim1) - shmem_dim2 = :($(metadata_module).shmem_dim2) - shmem_optvars = :($(metadata_module).shmem_optvars) - shmem_spans = :($(metadata_module).shmem_spans) - maxsize = :(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges($ranges)), $loopsizes)) - nthreads = :( ParallelStencil.compute_nthreads_memopt($nthreads_x_max, $nthreads_max_memopt, $maxsize, $loopdim, $stencilranges) ) - nblocks = :( ParallelStencil.ParallelKernel.compute_nblocks($maxsize, $nthreads) ) numbertype = get_numbertype(caller) # not :(eltype($(optvars)[1])) # TODO: see how to obtain number type properly for each array: the type of the call call arguments corresponding to the optimization variables should be checked - if get_nonconst_metadata(caller) - A = gensym("A") - shmem = :($use_any_shmem ? sum(($nthreads[$shmem_dim1] + $use_shmemhalos[$A] * ($(shmem_spans)[$A][1])) * ($nthreads[$shmem_dim2] + $use_shmemhalos[$A] * ($(shmem_spans)[$A][2])) * sizeof($numbertype) for $A in $shmem_optvars) : 0) + launch_config_var = gensym("launch_config") + nblocks_var = gensym("nblocks") + nthreads_var = gensym("nthreads") + shmem_var = gensym("shmem") + launch_config_expr = :(ParallelStencil.compute_memopt_launch_config(Val($metadata_expr.loopsizes), Val($metadata_expr.loopdim), Val($metadata_expr.stencilranges), $nthreads_x_max, $nthreads_max_memopt, $ranges)) + shmem_expr = :(ParallelStencil.compute_memopt_shmem(Val($metadata_expr.shmem_optvars), Val($metadata_expr.use_shmemhalos), Val($metadata_expr.shmem_spans), Val($metadata_expr.shmem_dim1), Val($metadata_expr.shmem_dim2), $nthreads_var, $numbertype)) + if async + return quote + local $launch_config_var = $launch_config_expr + local $nblocks_var = $launch_config_var[1] + local $nthreads_var = $launch_config_var[2] + local $shmem_var = $shmem_expr + @parallel_async memopt=false $configcall_kwarg_expr $ranges $nblocks_var $nthreads_var shmem=$shmem_var $(backend_kwargs_expr...) $kernelcall + end else - shmem = :(ParallelStencil.compute_memopt_shmem(Val($shmem_optvars), Val($use_shmemhalos), Val($shmem_spans), Val($shmem_dim1), Val($shmem_dim2), $nthreads, $numbertype)) + return quote + local $launch_config_var = $launch_config_expr + local $nblocks_var = $launch_config_var[1] + local $nthreads_var = $launch_config_var[2] + local $shmem_var = $shmem_expr + @parallel memopt=false $configcall_kwarg_expr $ranges $nblocks_var $nthreads_var shmem=$shmem_var $(backend_kwargs_expr...) $kernelcall + end end - if (async) return :(@parallel_async memopt=false $configcall_kwarg_expr $ranges $nblocks $nthreads shmem=$shmem $(backend_kwargs_expr...) $kernelcall) #TODO: the package and numbertype will have to be passed here further once supported as kwargs - else return :(@parallel memopt=false $configcall_kwarg_expr $ranges $nblocks $nthreads shmem=$shmem $(backend_kwargs_expr...) $kernelcall) #TODO: ... +end + +function parallel_call_memopt(caller::Module, ranges::Union{Symbol,Expr}, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool; memopt::Bool=false, configcall::Expr=kernelcall) + metadata_call = create_metadata_call(configcall) + metadata_var = gensym("metadata") + quote + local $metadata_var = $metadata_call + $(parallel_call_memopt(caller, metadata_var, ranges, kernelcall, backend_kwargs_expr, async; memopt=memopt, configcall=configcall)) end end @@ -398,16 +407,13 @@ function parallel_call_memopt(caller::Module, kernelcall::Expr, backend_kwargs_e nthreads_x_max = ParallelKernel.determine_nthreads_x_max(package) nthreads_max_memopt = determine_nthreads_max_memopt(package) metadata_call = create_metadata_call(configcall) - metadata_module = metadata_call - loopdim = :($(metadata_module).loopdim) - is_parallel_kernel = :($(metadata_module).is_parallel_kernel) - if get_nonconst_metadata(caller) - ranges = add_nb_parallel_indices_check(:( ($is_parallel_kernel) ? ParallelStencil.get_ranges_memopt($nthreads_x_max, $nthreads_max_memopt, $loopdim, $(configcall.args[2:end]...)) : ParallelStencil.ParallelKernel.get_ranges($(configcall.args[2:end]...))), configcall) - else - nb_parallel_indices = :($(metadata_module).nb_parallel_indices) - ranges = :(ParallelStencil.compute_memopt_ranges(Val($is_parallel_kernel), Val($nb_parallel_indices), Val($loopdim), $nthreads_x_max, $nthreads_max_memopt, $(configcall.args[2:end]...))) + metadata_var = gensym("metadata") + ranges_var = gensym("ranges") + quote + local $metadata_var = $metadata_call + local $ranges_var = ParallelStencil.compute_memopt_ranges(Val($metadata_var.is_parallel_kernel), Val($metadata_var.nb_parallel_indices), Val($metadata_var.loopdim), $nthreads_x_max, $nthreads_max_memopt, $(configcall.args[2:end]...)) + $(parallel_call_memopt(caller, metadata_var, ranges_var, kernelcall, backend_kwargs_expr, async; memopt=memopt, configcall=configcall)) end - parallel_call_memopt(caller, ranges, kernelcall, backend_kwargs_expr, async; memopt=memopt, configcall=configcall) end @@ -482,6 +488,34 @@ end end end +@generated function compute_memopt_launch_config(::Val{loopsizes}, ::Val{loopdim}, ::Val{stencilranges}, nthreads_x_max, nthreads_max_memopt, ranges) where {loopsizes, loopdim, stencilranges} + return quote + maxsize = cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)), $loopsizes) + nthreads = ParallelStencil.compute_nthreads_memopt(nthreads_x_max, nthreads_max_memopt, maxsize, $loopdim, $stencilranges) + nblocks = ParallelStencil.ParallelKernel.compute_nblocks(maxsize, nthreads) + (nblocks, nthreads) + end +end + +@generated function check_nb_parallel_indices(::Val{nb_parallel_indices}, args...) where {nb_parallel_indices} + errorcall = :(ParallelStencil.@ArgumentError(ParallelStencil.ERRMSG_AUTOMATIC_RANGES_PARALLEL)) + return quote + nb_input_dims = ParallelStencil.get_nb_input_dims(args...) + nb_dims_match = (nb_input_dims == $nb_parallel_indices) + if nb_dims_match isa Bool + nb_dims_match || $errorcall + end + nothing + end +end + +@generated function compute_parallel_ranges(::Val{nb_parallel_indices}, args...) where {nb_parallel_indices} + return quote + ParallelStencil.check_nb_parallel_indices(Val($nb_parallel_indices), args...) + ParallelStencil.ParallelKernel.get_ranges(args...) + end +end + @generated function compute_memopt_ranges(::Val{is_parallel_kernel}, ::Val{nb_parallel_indices}, ::Val{loopdim}, nthreads_x_max, nthreads_max_memopt, args...) where {is_parallel_kernel, nb_parallel_indices, loopdim} if is_parallel_kernel range_expr = :(ParallelStencil.get_ranges_memopt(nthreads_x_max, nthreads_max_memopt, $loopdim, args...)) @@ -708,10 +742,7 @@ end function add_nb_parallel_indices_check(ranges::Union{Symbol,Expr}, configcall::Expr) metadata_call = create_metadata_call(configcall) - nb_parallel_indices = :(($metadata_call).nb_parallel_indices) - nb_input_dims = :(ParallelStencil.get_nb_input_dims($(configcall.args[2:end]...))) - errorcall = :(ParallelStencil.@ArgumentError(ParallelStencil.ERRMSG_AUTOMATIC_RANGES_PARALLEL)) - return :(($nb_input_dims != $nb_parallel_indices && $errorcall; $ranges)) + return :(ParallelStencil.check_nb_parallel_indices(Val(($metadata_call).nb_parallel_indices), $(configcall.args[2:end]...)); $ranges) end get_nb_input_dims(args...) = maximum((get_nb_input_dims(arg) for arg in args); init=1) From 2fce91585602e1e649f02582ca77cdc1034faf04 Mon Sep 17 00:00:00 2001 From: Samuel Omlin Date: Tue, 17 Mar 2026 09:04:02 +0100 Subject: [PATCH 3/7] Update tests in test_parallel to use compute_parallel_ranges and new memory optimization functions --- test/test_parallel.jl | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/test/test_parallel.jl b/test/test_parallel.jl index f2b278a..cde26de 100644 --- a/test/test_parallel.jl +++ b/test/test_parallel.jl @@ -59,7 +59,7 @@ eval(:( @static if $package == $PKG_CUDA call = @prettystring(1, @parallel f(A)) @test occursin("CUDA.@cuda", call) - @test occursin("ParallelStencil.ParallelKernel.get_ranges(A)", call) + @test occursin("ParallelStencil.compute_parallel_ranges(Val", call) @test occursin("nb_parallel_indices", call) @test occursin("CUDA.synchronize(CUDA.stream(); blocking = true)", call) call = @prettystring(1, @parallel ranges f(A)) @@ -73,11 +73,14 @@ eval(:( call = @prettystring(2, @parallel memopt=true f(A)) # @test occursin("CUDA.@cuda blocks = ParallelStencil.ParallelKernel.compute_nblocks(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))),", call) # NOTE: now it is a very long multi line expression; before it continued as follows: (1, 1, 16)), ParallelStencil.compute_nthreads_memopt(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), (1, 1, 16)), 3, (-1:1, -1:1, -1:1))) threads = ParallelStencil.compute_nthreads_memopt(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), (1, 1, 16)), 3, (-1:1, -1:1, -1:1)) stream = CUDA.stream() shmem = ((ParallelStencil.compute_nthreads_memopt(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), (1, 1, 16)), 3, (-1:1, -1:1, -1:1)))[1] + 3) * ((ParallelStencil.compute_nthreads_memopt(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), (1, 1, 16)), 3, (-1:1, -1:1, -1:1)))[2] + 3) * sizeof(Float64) f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[3])))", call) call = @prettystring(2, @parallel ranges memopt=true f(A)) - @test occursin("CUDA.@cuda blocks = ParallelStencil.ParallelKernel.compute_nblocks(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)),", call) # NOTE: now it is a very long multi line expression; before it continued as follows: (1, 1, 16)), ParallelStencil.compute_nthreads_memopt(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)), (1, 1, 16)), 3, (-1:1, -1:1, -1:1))) threads = ParallelStencil.compute_nthreads_memopt(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)), (1, 1, 16)), 3, (-1:1, -1:1, -1:1)) stream = CUDA.stream() shmem = ((ParallelStencil.compute_nthreads_memopt(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)), (1, 1, 16)), 3, (-1:1, -1:1, -1:1)))[1] + 3) * ((ParallelStencil.compute_nthreads_memopt(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)), (1, 1, 16)), 3, (-1:1, -1:1, -1:1)))[2] + 3) * sizeof(Float64) f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[3])))", call) + @test occursin("ParallelStencil.compute_memopt_launch_config(Val", call) + @test occursin("ParallelStencil.compute_memopt_shmem(Val", call) + @test occursin("CUDA.@cuda blocks = var\"##nblocks", call) + @test occursin("f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges)", call) elseif $package == $PKG_AMDGPU call = @prettystring(1, @parallel f(A)) @test occursin("AMDGPU.@roc", call) - @test occursin("ParallelStencil.ParallelKernel.get_ranges(A)", call) + @test occursin("ParallelStencil.compute_parallel_ranges(Val", call) @test occursin("nb_parallel_indices", call) @test occursin("AMDGPU.synchronize(AMDGPU.stream(); blocking = true)", call) call = @prettystring(1, @parallel ranges f(A)) @@ -91,14 +94,17 @@ eval(:( call = @prettystring(2, @parallel memopt=true f(A)) # @test occursin("AMDGPU.@roc gridsize = ParallelStencil.ParallelKernel.compute_nblocks(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))),", call) # NOTE: now it is a very long multi line expression; before it continued as follows: (1, 1, 16)), ParallelStencil.compute_nthreads_memopt(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), (1, 1, 16)), 3, (-1:1, -1:1, -1:1))) groupsize = ParallelStencil.compute_nthreads_memopt(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), (1, 1, 16)), 3, (-1:1, -1:1, -1:1)) stream = AMDGPU.stream() shmem = ((ParallelStencil.compute_nthreads_memopt(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), (1, 1, 16)), 3, (-1:1, -1:1, -1:1)))[1] + 3) * ((ParallelStencil.compute_nthreads_memopt(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), (1, 1, 16)), 3, (-1:1, -1:1, -1:1)))[2] + 3) * sizeof(Float64) f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[3])))", call) call = @prettystring(2, @parallel ranges memopt=true f(A)) - @test occursin("AMDGPU.@roc gridsize = ParallelStencil.ParallelKernel.compute_nblocks(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)),", call) # NOTE: now it is a very long multi line expression; before it continued as follows: (1, 1, 16)), ParallelStencil.compute_nthreads_memopt(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)), (1, 1, 16)), 3, (-1:1, -1:1, -1:1))) groupsize = ParallelStencil.compute_nthreads_memopt(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)), (1, 1, 16)), 3, (-1:1, -1:1, -1:1)) stream = AMDGPU.stream() shmem = ((ParallelStencil.compute_nthreads_memopt(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)), (1, 1, 16)), 3, (-1:1, -1:1, -1:1)))[1] + 3) * ((ParallelStencil.compute_nthreads_memopt(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)), (1, 1, 16)), 3, (-1:1, -1:1, -1:1)))[2] + 3) * sizeof(Float64) f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[3])))", call) + @test occursin("ParallelStencil.compute_memopt_launch_config(Val", call) + @test occursin("ParallelStencil.compute_memopt_shmem(Val", call) + @test occursin("AMDGPU.@roc gridsize = var\"##nblocks", call) + @test occursin("f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges)", call) elseif $package == $PKG_KERNELABSTRACTIONS call = @prettystring(1, @parallel f(A)) @test occursin("ParallelStencil.ParallelKernel.@ka", call) @test occursin("handle(ParallelStencil.ParallelKernel.current_hardware(@__MODULE__()), :$PKG_KERNELABSTRACTIONS)", call) @test occursin("compute_nblocks", call) @test occursin("compute_nthreads", call) - @test occursin("ParallelStencil.ParallelKernel.get_ranges(A)", call) + @test occursin("ParallelStencil.compute_parallel_ranges(Val", call) @test occursin("nb_parallel_indices", call) @test !occursin("CUDA.@cuda", call) @test !occursin("AMDGPU.@roc", call) @@ -121,17 +127,20 @@ eval(:( elseif @iscpu($package) call = @prettystring(1, @parallel f(A)) @test occursin("f(A, ParallelStencil.ParallelKernel.promote_ranges(", call) - @test occursin("ParallelStencil.ParallelKernel.get_ranges(A)", call) + @test occursin("ParallelStencil.compute_parallel_ranges(Val", call) @test occursin("nb_parallel_indices", call) @test @prettystring(1, @parallel ranges f(A)) == "f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[3])))" @test @prettystring(1, @parallel nblocks nthreads f(A)) == "f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)))[3])))" @test @prettystring(1, @parallel ranges nblocks nthreads f(A)) == "f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[3])))" call = @prettystring(1, @parallel stream=mystream f(A)) @test occursin("f(A, ParallelStencil.ParallelKernel.promote_ranges(", call) - @test occursin("ParallelStencil.ParallelKernel.get_ranges(A)", call) + @test occursin("ParallelStencil.compute_parallel_ranges(Val", call) @test occursin("nb_parallel_indices", call) # @test @prettystring(2, @parallel memopt=true f(A)) == "f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[3])))" - @test @prettystring(2, @parallel ranges memopt=true f(A)) == "f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[3])))" + call = @prettystring(2, @parallel ranges memopt=true f(A)) + @test occursin("ParallelStencil.compute_memopt_launch_config(Val", call) + @test occursin("ParallelStencil.compute_memopt_shmem(Val", call) + @test occursin("f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges)", call) end; end; @testset "KernelAbstractions runtime reselection" begin From f563ac4295b188216d167359b0de49cbbd1eeb43 Mon Sep 17 00:00:00 2001 From: Samuel Omlin Date: Tue, 17 Mar 2026 12:35:26 +0100 Subject: [PATCH 4/7] Implement precomputation for parallel memory optimization arguments and refactor kernel call handling --- src/parallel.jl | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/src/parallel.jl b/src/parallel.jl index f384f11..7ab5f51 100644 --- a/src/parallel.jl +++ b/src/parallel.jl @@ -361,6 +361,15 @@ end ## @PARALLEL CALL FUNCTIONS +function precompute_parallel_memopt_arg(arg::Union{Symbol,Expr}, prefix::AbstractString) + if isa(arg, Symbol) + return Expr[], arg + else + arg_var = gensym(prefix) + return [:(local $arg_var = $arg)], arg_var + end +end + function parallel_call_memopt(caller::Module, metadata_expr::Union{Symbol,Expr}, ranges::Union{Symbol,Expr}, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool; memopt::Bool=false, configcall::Expr=kernelcall) if haskey(backend_kwargs_expr, :shmem) @KeywordArgumentError("@parallel : keyword `shmem` is not allowed when memopt=true is set.") end package = get_package(caller) @@ -368,27 +377,25 @@ function parallel_call_memopt(caller::Module, metadata_expr::Union{Symbol,Expr}, nthreads_max_memopt = determine_nthreads_max_memopt(package) configcall_kwarg_expr = :(configcall=$configcall) numbertype = get_numbertype(caller) # not :(eltype($(optvars)[1])) # TODO: see how to obtain number type properly for each array: the type of the call call arguments corresponding to the optimization variables should be checked - launch_config_var = gensym("launch_config") nblocks_var = gensym("nblocks") nthreads_var = gensym("nthreads") shmem_var = gensym("shmem") - launch_config_expr = :(ParallelStencil.compute_memopt_launch_config(Val($metadata_expr.loopsizes), Val($metadata_expr.loopdim), Val($metadata_expr.stencilranges), $nthreads_x_max, $nthreads_max_memopt, $ranges)) + range_setup_exprs, range_arg = precompute_parallel_memopt_arg(ranges, "ranges") + nthreads_nblocks_expr = :(ParallelStencil.compute_memopt_nthreads_nblocks(Val($metadata_expr.loopsizes), Val($metadata_expr.loopdim), Val($metadata_expr.stencilranges), $nthreads_x_max, $nthreads_max_memopt, $range_arg)) shmem_expr = :(ParallelStencil.compute_memopt_shmem(Val($metadata_expr.shmem_optvars), Val($metadata_expr.use_shmemhalos), Val($metadata_expr.shmem_spans), Val($metadata_expr.shmem_dim1), Val($metadata_expr.shmem_dim2), $nthreads_var, $numbertype)) if async return quote - local $launch_config_var = $launch_config_expr - local $nblocks_var = $launch_config_var[1] - local $nthreads_var = $launch_config_var[2] + $(range_setup_exprs...) + local $nblocks_var, $nthreads_var = $nthreads_nblocks_expr local $shmem_var = $shmem_expr - @parallel_async memopt=false $configcall_kwarg_expr $ranges $nblocks_var $nthreads_var shmem=$shmem_var $(backend_kwargs_expr...) $kernelcall + @parallel_async memopt=false $configcall_kwarg_expr $range_arg $nblocks_var $nthreads_var shmem=$shmem_var $(backend_kwargs_expr...) $kernelcall end else return quote - local $launch_config_var = $launch_config_expr - local $nblocks_var = $launch_config_var[1] - local $nthreads_var = $launch_config_var[2] + $(range_setup_exprs...) + local $nblocks_var, $nthreads_var = $nthreads_nblocks_expr local $shmem_var = $shmem_expr - @parallel memopt=false $configcall_kwarg_expr $ranges $nblocks_var $nthreads_var shmem=$shmem_var $(backend_kwargs_expr...) $kernelcall + @parallel memopt=false $configcall_kwarg_expr $range_arg $nblocks_var $nthreads_var shmem=$shmem_var $(backend_kwargs_expr...) $kernelcall end end end @@ -488,7 +495,7 @@ end end end -@generated function compute_memopt_launch_config(::Val{loopsizes}, ::Val{loopdim}, ::Val{stencilranges}, nthreads_x_max, nthreads_max_memopt, ranges) where {loopsizes, loopdim, stencilranges} +@generated function compute_memopt_nthreads_nblocks(::Val{loopsizes}, ::Val{loopdim}, ::Val{stencilranges}, nthreads_x_max, nthreads_max_memopt, ranges) where {loopsizes, loopdim, stencilranges} return quote maxsize = cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)), $loopsizes) nthreads = ParallelStencil.compute_nthreads_memopt(nthreads_x_max, nthreads_max_memopt, maxsize, $loopdim, $stencilranges) From 2833f2c3f7bbcb1f71768c7bc566a557dda40845 Mon Sep 17 00:00:00 2001 From: Samuel Omlin Date: Tue, 17 Mar 2026 12:35:34 +0100 Subject: [PATCH 5/7] Update tests in test_parallel to check for compute_memopt_nthreads_nblocks instead of compute_memopt_launch_config --- test/test_parallel.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_parallel.jl b/test/test_parallel.jl index cde26de..b845571 100644 --- a/test/test_parallel.jl +++ b/test/test_parallel.jl @@ -73,7 +73,7 @@ eval(:( call = @prettystring(2, @parallel memopt=true f(A)) # @test occursin("CUDA.@cuda blocks = ParallelStencil.ParallelKernel.compute_nblocks(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))),", call) # NOTE: now it is a very long multi line expression; before it continued as follows: (1, 1, 16)), ParallelStencil.compute_nthreads_memopt(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), (1, 1, 16)), 3, (-1:1, -1:1, -1:1))) threads = ParallelStencil.compute_nthreads_memopt(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), (1, 1, 16)), 3, (-1:1, -1:1, -1:1)) stream = CUDA.stream() shmem = ((ParallelStencil.compute_nthreads_memopt(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), (1, 1, 16)), 3, (-1:1, -1:1, -1:1)))[1] + 3) * ((ParallelStencil.compute_nthreads_memopt(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), (1, 1, 16)), 3, (-1:1, -1:1, -1:1)))[2] + 3) * sizeof(Float64) f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[3])))", call) call = @prettystring(2, @parallel ranges memopt=true f(A)) - @test occursin("ParallelStencil.compute_memopt_launch_config(Val", call) + @test occursin("ParallelStencil.compute_memopt_nthreads_nblocks(Val", call) @test occursin("ParallelStencil.compute_memopt_shmem(Val", call) @test occursin("CUDA.@cuda blocks = var\"##nblocks", call) @test occursin("f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges)", call) @@ -94,7 +94,7 @@ eval(:( call = @prettystring(2, @parallel memopt=true f(A)) # @test occursin("AMDGPU.@roc gridsize = ParallelStencil.ParallelKernel.compute_nblocks(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))),", call) # NOTE: now it is a very long multi line expression; before it continued as follows: (1, 1, 16)), ParallelStencil.compute_nthreads_memopt(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), (1, 1, 16)), 3, (-1:1, -1:1, -1:1))) groupsize = ParallelStencil.compute_nthreads_memopt(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), (1, 1, 16)), 3, (-1:1, -1:1, -1:1)) stream = AMDGPU.stream() shmem = ((ParallelStencil.compute_nthreads_memopt(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), (1, 1, 16)), 3, (-1:1, -1:1, -1:1)))[1] + 3) * ((ParallelStencil.compute_nthreads_memopt(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), (1, 1, 16)), 3, (-1:1, -1:1, -1:1)))[2] + 3) * sizeof(Float64) f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[3])))", call) call = @prettystring(2, @parallel ranges memopt=true f(A)) - @test occursin("ParallelStencil.compute_memopt_launch_config(Val", call) + @test occursin("ParallelStencil.compute_memopt_nthreads_nblocks(Val", call) @test occursin("ParallelStencil.compute_memopt_shmem(Val", call) @test occursin("AMDGPU.@roc gridsize = var\"##nblocks", call) @test occursin("f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges)", call) @@ -138,7 +138,7 @@ eval(:( @test occursin("nb_parallel_indices", call) # @test @prettystring(2, @parallel memopt=true f(A)) == "f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[3])))" call = @prettystring(2, @parallel ranges memopt=true f(A)) - @test occursin("ParallelStencil.compute_memopt_launch_config(Val", call) + @test occursin("ParallelStencil.compute_memopt_nthreads_nblocks(Val", call) @test occursin("ParallelStencil.compute_memopt_shmem(Val", call) @test occursin("f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges)", call) end; From 6ebe602c51d0bc2dbdb32f2060c3c4bcbd3b857b Mon Sep 17 00:00:00 2001 From: Samuel Omlin Date: Tue, 17 Mar 2026 13:52:01 +0100 Subject: [PATCH 6/7] Refactor parallel memory optimization functions and update documentation for kernel call handling --- src/parallel.jl | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/src/parallel.jl b/src/parallel.jl index 7ab5f51..150173d 100644 --- a/src/parallel.jl +++ b/src/parallel.jl @@ -38,7 +38,7 @@ Declare the `kernelcall` parallel. The kernel will automatically be called as re Automatic computation of `ranges` for `@parallel ` is only possible if the number of parallel indices used by the kernel is equal to the number of dimensions of the highest-dimensional input arrays. Otherwise, specify the `ranges` manually with `@parallel ranges=... `. !!! note "Runtime hardware selection" - When KernelAbstractions is initialized, this wrapper consults [`current_hardware`](@ref) to determine the runtime hardware target. The symbol defaults to `:cpu` and can be switched to select other targets via [`select_hardware`](@ref). + When KernelAbstractions is chosen as the package for parallelization, this wrapper consults [`current_hardware`](@ref) to determine the runtime hardware target. The symbol defaults to `:cpu` and can be switched to select other targets via [`select_hardware`](@ref). # Arguments - `kernelcall`: a call to a kernel that is declared parallel. @@ -361,15 +361,6 @@ end ## @PARALLEL CALL FUNCTIONS -function precompute_parallel_memopt_arg(arg::Union{Symbol,Expr}, prefix::AbstractString) - if isa(arg, Symbol) - return Expr[], arg - else - arg_var = gensym(prefix) - return [:(local $arg_var = $arg)], arg_var - end -end - function parallel_call_memopt(caller::Module, metadata_expr::Union{Symbol,Expr}, ranges::Union{Symbol,Expr}, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool; memopt::Bool=false, configcall::Expr=kernelcall) if haskey(backend_kwargs_expr, :shmem) @KeywordArgumentError("@parallel : keyword `shmem` is not allowed when memopt=true is set.") end package = get_package(caller) @@ -450,7 +441,7 @@ function compute_loopsize(package::Symbol) end -## FUNCTIONS TO COMPUTE NTHREADS, NBLOCKS +## FUNCTIONS TO COMPUTE NTHREADS, NBLOCKS, SHARED MEMORY SIZE AND RANGES function compute_nthreads_memopt(nthreads_x_max, nthreads_max_memopt, maxsize, loopdim, stencilranges) # This is a heuristic, which results typcially in (32,4,1) threads for a 3-D case. maxsize = promote_maxsize(maxsize) @@ -478,8 +469,6 @@ function get_ranges_memopt(nthreads_x_max, nthreads_max_memopt, loopdim, args... end -## FUNCTIONS TO COMPUTE SHARED MEMORY SIZE AND RANGES FOR MEMOPT - @generated function compute_memopt_shmem(::Val{optvars}, ::Val{use_shmemhalos}, ::Val{shmem_spans}, ::Val{shmem_dim1}, ::Val{shmem_dim2}, nthreads, ::Type{T}) where {optvars, use_shmemhalos, shmem_spans, shmem_dim1, shmem_dim2, T} terms = [:( (nthreads[$shmem_dim1] + $(getproperty(use_shmemhalos, A)) * $(getproperty(shmem_spans, A)[1])) * @@ -540,6 +529,15 @@ end end end +function precompute_parallel_memopt_arg(arg::Union{Symbol,Expr}, prefix::AbstractString) + if isa(arg, Symbol) + return Expr[], arg + else + arg_var = gensym(prefix) + return [:(local $arg_var = $arg)], arg_var + end +end + ## FUNCTIONS TO DEAL WITH MASKS (@WITHIN) AND INDICES From 2f0587e3b2a3fe206cce82f979e0be62156ce459 Mon Sep 17 00:00:00 2001 From: Samuel Omlin Date: Tue, 17 Mar 2026 13:52:09 +0100 Subject: [PATCH 7/7] Update documentation for kernelcall to clarify package selection for parallelization --- src/ParallelKernel/parallel.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ParallelKernel/parallel.jl b/src/ParallelKernel/parallel.jl index 1138d4c..a5c0579 100644 --- a/src/ParallelKernel/parallel.jl +++ b/src/ParallelKernel/parallel.jl @@ -15,7 +15,7 @@ Declare the `kernelcall` parallel. The kernel will automatically be called as re Automatic computation of `ranges` for `@parallel ` is only possible if the number of parallel indices used by the kernel is equal to the number of dimensions of the highest-dimensional input arrays. Otherwise, specify the `ranges` manually with `@parallel ranges=... `. !!! note "Runtime hardware selection" - When KernelAbstractions is initialized, this wrapper consults [`current_hardware`](@ref) to determine the runtime hardware target. The symbol defaults to `:cpu` and can be switched to select other targets via [`select_hardware`](@ref). + When KernelAbstractions is chosen as the package for parallelization, this wrapper consults [`current_hardware`](@ref) to determine the runtime hardware target. The symbol defaults to `:cpu` and can be switched to select other targets via [`select_hardware`](@ref). # Arguments - `kernelcall`: a call to a kernel that is declared parallel.