From 7120aad91089c6cce8b01afe7a9bcc6846cfcab9 Mon Sep 17 00:00:00 2001 From: Samuel Omlin Date: Fri, 13 Mar 2026 13:24:42 +0100 Subject: [PATCH 01/14] Refactor shared memory calculation in parallel_call_memopt to handle cases with single range dimensions --- src/parallel.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/parallel.jl b/src/parallel.jl index e07e211..863e80b 100644 --- a/src/parallel.jl +++ b/src/parallel.jl @@ -382,7 +382,7 @@ function parallel_call_memopt(caller::Module, ranges::Union{Symbol,Expr}, kernel dim1 = :(($loopdim==3) ? 1 : ($loopdim==2) ? 1 : 2) # TODO: to be determined if that is what is desired for loopdim 1 and 2. dim2 = :(($loopdim==3) ? 2 : ($loopdim==2) ? 3 : 3) # TODO: to be determined if that is what is desired for loopdim 1 and 2. A = gensym("A") - shmem = :(sum(($nthreads[$dim1]+$use_shmemhalos[$A]*(length($(stencilranges)[$A][$dim1])-1))*($nthreads[$dim2]+$use_shmemhalos[$A]*(length($(stencilranges)[$A][$dim2])-1))*sizeof($numbertype) for $A in $optvars)) + shmem = :((length(stencilranges[$A][$dim1]) > 1) || (length(stencilranges[$A][$dim2]) > 1) ? sum(($nthreads[$dim1]+$use_shmemhalos[$A]*(length($(stencilranges)[$A][$dim1])-1))*($nthreads[$dim2]+$use_shmemhalos[$A]*(length($(stencilranges)[$A][$dim2])-1))*sizeof($numbertype) for $A in $optvars) : 0) if (async) return :(@parallel_async memopt=false $configcall_kwarg_expr $ranges $nblocks $nthreads shmem=$shmem $(backend_kwargs_expr...) $kernelcall) #TODO: the package and numbertype will have to be passed here further once supported as kwargs else return :(@parallel memopt=false $configcall_kwarg_expr $ranges $nblocks $nthreads shmem=$shmem $(backend_kwargs_expr...) $kernelcall) #TODO: ... end From 1ceb52f639c1ce3696bf961be5152167d3ad21dd Mon Sep 17 00:00:00 2001 From: Samuel Omlin Date: Fri, 13 Mar 2026 15:26:09 +0100 Subject: [PATCH 02/14] Enhance memory optimization in memopt by adding shared memory optimization variables and updating metadata storage --- src/memopt.jl | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/src/memopt.jl b/src/memopt.jl index ccd3303..71ef15b 100644 --- a/src/memopt.jl +++ b/src/memopt.jl @@ -90,6 +90,7 @@ function memopt(metadata_module::Module, is_parallel_kernel::Bool, caller::Modul loopstart = minimum(values(loopentrys)) loopend = loopsize use_any_shmem = any(values(use_shmems)) + shmem_optvars = tuple((A for A in optvars if use_shmems[A])...) shmem_index_groups = define_shmem_index_groups(hx1s, hy1s, hx2s, hy2s, optvars, use_shmems, loopdim) shmem_vars = define_shmem_vars(oz_maxs, hx1s, hy1s, hx2s, hy2s, optvars, indices, use_shmems, use_shmem_xs, use_shmem_ys, shmem_index_groups, use_shmemhalos, use_shmemindices, loopdim) shmem_exprs = define_shmem_exprs(shmem_vars, loopdim) @@ -469,7 +470,7 @@ $(( # NOTE: the if statement is not needed here as we only deal with registers else @ArgumentError("memopt: only loopdim=3 is currently supported.") end - store_metadata(metadata_module, is_parallel_kernel, caller, offset_mins, offset_maxs, offsets, optvars, loopdim, loopsize, optranges, use_shmemhalos) + store_metadata(metadata_module, is_parallel_kernel, caller, offset_mins, offset_maxs, offsets, optvars, shmem_optvars, use_any_shmem, loopdim, loopsize, optranges, use_shmemhalos) # @show QuoteNode(ParallelKernel.simplify_varnames!(ParallelKernel.remove_linenumbernodes!(deepcopy(body)))) return body end @@ -1019,10 +1020,14 @@ function wrap_loop(index::Symbol, range::UnitRange, block::Expr; unroll=false) end end -function store_metadata(metadata_module::Module, is_parallel_kernel::Bool, caller::Module, offset_mins::Dict{Symbol, <:NTuple{3,Integer}}, offset_maxs::Dict{Symbol, <:NTuple{3,Integer}}, offsets::Dict{Symbol, Dict{Any, Any}}, optvars::NTuple{N,Symbol} where N, loopdim::Integer, loopsize::Integer, optranges::Dict{Any, Any}, use_shmemhalos) +function store_metadata(metadata_module::Module, is_parallel_kernel::Bool, caller::Module, offset_mins::Dict{Symbol, <:NTuple{3,Integer}}, offset_maxs::Dict{Symbol, <:NTuple{3,Integer}}, offsets::Dict{Symbol, Dict{Any, Any}}, optvars::NTuple{N,Symbol} where N, shmem_optvars::Tuple, use_any_shmem::Bool, loopdim::Integer, loopsize::Integer, optranges::Dict{Any, Any}, use_shmemhalos) memopt = true nonconst_metadata = get_nonconst_metadata(caller) stencilranges = NamedTuple(A => (offset_mins[A][1]:offset_maxs[A][1], offset_mins[A][2]:offset_maxs[A][2], offset_mins[A][3]:offset_maxs[A][3]) for A in optvars) + loopsizes = (loopdim==3) ? (1, 1, loopsize) : (loopdim==2) ? (1, loopsize, 1) : (loopsize, 1, 1) + shmem_dim1 = (loopdim==3) ? 1 : (loopdim==2) ? 1 : 2 + shmem_dim2 = (loopdim==3) ? 2 : (loopdim==2) ? 3 : 3 + shmem_spans = NamedTuple(A => (length(stencilranges[A][shmem_dim1]) - 1, length(stencilranges[A][shmem_dim2]) - 1) for A in optvars) if nonconst_metadata storeexpr = quote is_parallel_kernel = $is_parallel_kernel @@ -1031,9 +1036,15 @@ function store_metadata(metadata_module::Module, is_parallel_kernel::Bool, calle stencilranges = $stencilranges offsets = $offsets optvars = $optvars + shmem_optvars = $shmem_optvars + shmem_spans = $shmem_spans loopdim = $loopdim loopsize = $loopsize + loopsizes = $loopsizes + shmem_dim1 = $shmem_dim1 + shmem_dim2 = $shmem_dim2 optranges = $optranges + use_any_shmem = $use_any_shmem use_shmemhalos = $use_shmemhalos end else @@ -1044,9 +1055,15 @@ function store_metadata(metadata_module::Module, is_parallel_kernel::Bool, calle const stencilranges = $stencilranges const offsets = $offsets const optvars = $optvars + const shmem_optvars = $shmem_optvars + const shmem_spans = $shmem_spans const loopdim = $loopdim const loopsize = $loopsize + const loopsizes = $loopsizes + const shmem_dim1 = $shmem_dim1 + const shmem_dim2 = $shmem_dim2 const optranges = $optranges + const use_any_shmem = $use_any_shmem const use_shmemhalos = $use_shmemhalos end end From a013531520c30dd09545f41d0646b9385b81f592 Mon Sep 17 00:00:00 2001 From: Samuel Omlin Date: Fri, 13 Mar 2026 15:26:17 +0100 Subject: [PATCH 03/14] Refactor shared memory handling in parallel_call_memopt to improve clarity and optimize performance with new metadata variables --- src/parallel.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/parallel.jl b/src/parallel.jl index 863e80b..cd7f1c8 100644 --- a/src/parallel.jl +++ b/src/parallel.jl @@ -369,20 +369,21 @@ function parallel_call_memopt(caller::Module, ranges::Union{Symbol,Expr}, kernel configcall_kwarg_expr = :(configcall=$configcall) metadata_call = create_metadata_call(configcall) metadata_module = metadata_call + loopdim = :($(metadata_module).loopdim) + loopsizes = :($(metadata_module).loopsizes) stencilranges = :($(metadata_module).stencilranges) use_shmemhalos = :($(metadata_module).use_shmemhalos) - optvars = :($(metadata_module).optvars) - loopdim = :($(metadata_module).loopdim) - loopsize = :($(metadata_module).loopsize) - loopsizes = :(($loopdim==3) ? (1, 1, $loopsize) : ($loopdim==2) ? (1, $loopsize, 1) : ($loopsize, 1, 1)) + use_any_shmem = :($(metadata_module).use_any_shmem) + shmem_dim1 = :($(metadata_module).shmem_dim1) + shmem_dim2 = :($(metadata_module).shmem_dim2) + shmem_optvars = :($(metadata_module).shmem_optvars) + shmem_spans = :($(metadata_module).shmem_spans) maxsize = :(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges($ranges)), $loopsizes)) nthreads = :( ParallelStencil.compute_nthreads_memopt($nthreads_x_max, $nthreads_max_memopt, $maxsize, $loopdim, $stencilranges) ) nblocks = :( ParallelStencil.ParallelKernel.compute_nblocks($maxsize, $nthreads) ) numbertype = get_numbertype(caller) # not :(eltype($(optvars)[1])) # TODO: see how to obtain number type properly for each array: the type of the call call arguments corresponding to the optimization variables should be checked - dim1 = :(($loopdim==3) ? 1 : ($loopdim==2) ? 1 : 2) # TODO: to be determined if that is what is desired for loopdim 1 and 2. - dim2 = :(($loopdim==3) ? 2 : ($loopdim==2) ? 3 : 3) # TODO: to be determined if that is what is desired for loopdim 1 and 2. A = gensym("A") - shmem = :((length(stencilranges[$A][$dim1]) > 1) || (length(stencilranges[$A][$dim2]) > 1) ? sum(($nthreads[$dim1]+$use_shmemhalos[$A]*(length($(stencilranges)[$A][$dim1])-1))*($nthreads[$dim2]+$use_shmemhalos[$A]*(length($(stencilranges)[$A][$dim2])-1))*sizeof($numbertype) for $A in $optvars) : 0) + shmem = :($use_any_shmem ? sum(($nthreads[$shmem_dim1] + $use_shmemhalos[$A] * ($(shmem_spans)[$A][1])) * ($nthreads[$shmem_dim2] + $use_shmemhalos[$A] * ($(shmem_spans)[$A][2])) * sizeof($numbertype) for $A in $shmem_optvars) : 0) if (async) return :(@parallel_async memopt=false $configcall_kwarg_expr $ranges $nblocks $nthreads shmem=$shmem $(backend_kwargs_expr...) $kernelcall) #TODO: the package and numbertype will have to be passed here further once supported as kwargs else return :(@parallel memopt=false $configcall_kwarg_expr $ranges $nblocks $nthreads shmem=$shmem $(backend_kwargs_expr...) $kernelcall) #TODO: ... end From df47fc6f7a17187893cf0a5bac6beca228b31ced Mon Sep 17 00:00:00 2001 From: Samuel Omlin Date: Fri, 13 Mar 2026 15:26:26 +0100 Subject: [PATCH 04/14] Enhance metadata tests in test_parallel.jl to include new shared memory variables and update assertions for loop sizes and spans --- test/test_parallel.jl | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/test/test_parallel.jl b/test/test_parallel.jl index eb231c0..1703271 100644 --- a/test/test_parallel.jl +++ b/test/test_parallel.jl @@ -1263,18 +1263,33 @@ eval(:( metadata_symbols = sort(setdiff(names(metadata; all=true), names(metadata))) @test metadata isa Module @test length(names(metadata)) == 1 - @test metadata_symbols == [:is_parallel_kernel, :loopdim, :loopsize, :memopt, :nb_parallel_indices, :nonconst_metadata, :offsets, :optranges, :optvars, :stencilranges, :use_shmemhalos] + @test metadata_symbols == [:is_parallel_kernel, :loopdim, :loopsize, :loopsizes, :memopt, :nb_parallel_indices, :nonconst_metadata, :offsets, :optranges, :optvars, :shmem_dim1, :shmem_dim2, :shmem_optvars, :shmem_spans, :stencilranges, :use_any_shmem, :use_shmemhalos] @test metadata.is_parallel_kernel == false @test metadata.loopdim == 3 @test metadata.loopsize == 3 + @test metadata.loopsizes == (1, 1, 3) @test metadata.memopt == true @test metadata.nb_parallel_indices == 3 @test metadata.nonconst_metadata == true @test metadata.offsets[:B][(0, 0)][0] == 1 @test metadata.optranges[:B] == (0:0, 0:0, 0:0) @test metadata.optvars == (:B,) + @test metadata.shmem_dim1 == 1 + @test metadata.shmem_dim2 == 2 + @test metadata.shmem_optvars == () + @test metadata.shmem_spans == (B = (0, 0),) @test metadata.stencilranges == (B = (0:0, 0:0, 0:0),) + @test metadata.use_any_shmem == false @test metadata.use_shmemhalos[:B] == true + @parallel_indices (ix, iy, iz) memopt=true loopsize=3 optvars=B optranges=(B=(0:0,0:0,-1:1),) function metadata_memopt_zstencil_probe!(A, B, D) + A[ix, iy, iz] = B[ix, iy, iz-1] + B[ix, iy, iz] + B[ix, iy, iz+1] + D[ix, iy, iz, 1] + return + end + metadata_z = @metadata metadata_memopt_zstencil_probe!(A, B, D) + @test metadata_z.shmem_optvars == () + @test metadata_z.shmem_spans == (B = (0, 0),) + @test metadata_z.stencilranges == (B = (0:0, 0:0, -1:1),) + @test metadata_z.use_any_shmem == false @test all(Array(A) .== 0) end; end; @@ -1360,10 +1375,4 @@ eval(:( end; )) -eval(:( - @testset "$(basename(@__FILE__)) metadata (package: $(nameof($package)))" begin - - end; -)) - end == nothing || true; From 48b436c4122725e8d591433560219ae081a4637c Mon Sep 17 00:00:00 2001 From: Samuel Omlin Date: Fri, 13 Mar 2026 16:16:03 +0100 Subject: [PATCH 05/14] Enhance metadata tests in test_parallel.jl to validate shared memory optimization variables and their types for different scenarios --- test/test_parallel.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/test_parallel.jl b/test/test_parallel.jl index 1703271..dc555e3 100644 --- a/test/test_parallel.jl +++ b/test/test_parallel.jl @@ -1277,6 +1277,7 @@ eval(:( @test metadata.shmem_dim1 == 1 @test metadata.shmem_dim2 == 2 @test metadata.shmem_optvars == () + @test metadata.shmem_optvars isa NTuple{0,Symbol} @test metadata.shmem_spans == (B = (0, 0),) @test metadata.stencilranges == (B = (0:0, 0:0, 0:0),) @test metadata.use_any_shmem == false @@ -1287,9 +1288,28 @@ eval(:( end metadata_z = @metadata metadata_memopt_zstencil_probe!(A, B, D) @test metadata_z.shmem_optvars == () + @test metadata_z.shmem_optvars isa NTuple{0,Symbol} @test metadata_z.shmem_spans == (B = (0, 0),) @test metadata_z.stencilranges == (B = (0:0, 0:0, -1:1),) @test metadata_z.use_any_shmem == false + @parallel_indices (ix, iy, iz) memopt=true loopsize=3 optvars=B function metadata_memopt_fullstencil_probe!(A, B, D) + A[ix, iy, iz] = B[ix-1, iy-1, iz-1] + B[ix, iy, iz] + B[ix+1, iy+1, iz+1] + D[ix, iy, iz, 1] + return + end + metadata_full = @metadata metadata_memopt_fullstencil_probe!(A, B, D) + @static if @isgpu($package) + @test metadata_full.shmem_optvars == (:B,) + @test metadata_full.shmem_optvars isa NTuple{1,Symbol} + @test metadata_full.shmem_spans == (B = (2, 2),) + @test metadata_full.stencilranges == (B = (-1:1, -1:1, -1:1),) + @test metadata_full.use_any_shmem == true + else + @test metadata_full.shmem_optvars == () + @test metadata_full.shmem_optvars isa NTuple{0,Symbol} + @test metadata_full.shmem_spans == (B = (0, 0),) + @test metadata_full.stencilranges == (B = (0:0, 0:0, 0:0),) + @test metadata_full.use_any_shmem == false + end @test all(Array(A) .== 0) end; end; From aa4d8d8fb6e936e25d4d3a3d58a582a376ce8535 Mon Sep 17 00:00:00 2001 From: Samuel Omlin Date: Fri, 13 Mar 2026 16:16:12 +0100 Subject: [PATCH 06/14] Refactor store_metadata function to specify type for shared memory optimization variables --- src/memopt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/memopt.jl b/src/memopt.jl index 71ef15b..395ffcb 100644 --- a/src/memopt.jl +++ b/src/memopt.jl @@ -90,7 +90,7 @@ function memopt(metadata_module::Module, is_parallel_kernel::Bool, caller::Modul loopstart = minimum(values(loopentrys)) loopend = loopsize use_any_shmem = any(values(use_shmems)) - shmem_optvars = tuple((A for A in optvars if use_shmems[A])...) + shmem_optvars = tuple((A for A in optvars if use_shmems[A])...)::Tuple{Vararg{Symbol}} shmem_index_groups = define_shmem_index_groups(hx1s, hy1s, hx2s, hy2s, optvars, use_shmems, loopdim) shmem_vars = define_shmem_vars(oz_maxs, hx1s, hy1s, hx2s, hy2s, optvars, indices, use_shmems, use_shmem_xs, use_shmem_ys, shmem_index_groups, use_shmemhalos, use_shmemindices, loopdim) shmem_exprs = define_shmem_exprs(shmem_vars, loopdim) @@ -1020,7 +1020,7 @@ function wrap_loop(index::Symbol, range::UnitRange, block::Expr; unroll=false) end end -function store_metadata(metadata_module::Module, is_parallel_kernel::Bool, caller::Module, offset_mins::Dict{Symbol, <:NTuple{3,Integer}}, offset_maxs::Dict{Symbol, <:NTuple{3,Integer}}, offsets::Dict{Symbol, Dict{Any, Any}}, optvars::NTuple{N,Symbol} where N, shmem_optvars::Tuple, use_any_shmem::Bool, loopdim::Integer, loopsize::Integer, optranges::Dict{Any, Any}, use_shmemhalos) +function store_metadata(metadata_module::Module, is_parallel_kernel::Bool, caller::Module, offset_mins::Dict{Symbol, <:NTuple{3,Integer}}, offset_maxs::Dict{Symbol, <:NTuple{3,Integer}}, offsets::Dict{Symbol, Dict{Any, Any}}, optvars::NTuple{N,Symbol} where N, shmem_optvars::NTuple{M,Symbol} where M, use_any_shmem::Bool, loopdim::Integer, loopsize::Integer, optranges::Dict{Any, Any}, use_shmemhalos) memopt = true nonconst_metadata = get_nonconst_metadata(caller) stencilranges = NamedTuple(A => (offset_mins[A][1]:offset_maxs[A][1], offset_mins[A][2]:offset_maxs[A][2], offset_mins[A][3]:offset_maxs[A][3]) for A in optvars) From 9826628c0a7fa9eb52298536a0f32ecd263d7979 Mon Sep 17 00:00:00 2001 From: Samuel Omlin Date: Mon, 16 Mar 2026 15:18:49 +0100 Subject: [PATCH 07/14] reactant --- src/parallel.jl | 52 +++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 48 insertions(+), 4 deletions(-) diff --git a/src/parallel.jl b/src/parallel.jl index cd7f1c8..5983695 100644 --- a/src/parallel.jl +++ b/src/parallel.jl @@ -361,6 +361,38 @@ end ## @PARALLEL CALL FUNCTIONS +@generated function compute_memopt_shmem(::Val{optvars}, ::Val{use_shmemhalos}, ::Val{shmem_spans}, ::Val{shmem_dim1}, ::Val{shmem_dim2}, nthreads, ::Type{T}) where {optvars, use_shmemhalos, shmem_spans, shmem_dim1, shmem_dim2, T} + terms = [:( + (nthreads[$shmem_dim1] + $(getproperty(use_shmemhalos, A)) * $(getproperty(shmem_spans, A)[1])) * + (nthreads[$shmem_dim2] + $(getproperty(use_shmemhalos, A)) * $(getproperty(shmem_spans, A)[2])) * + sizeof(T) + ) for A in optvars] + if isempty(terms) + return :(0) + elseif length(terms) == 1 + return terms[1] + else + return Expr(:call, :+, terms...) + end +end + +@generated function compute_memopt_ranges(::Val{is_parallel_kernel}, ::Val{nb_parallel_indices}, ::Val{loopdim}, nthreads_x_max, nthreads_max_memopt, args...) where {is_parallel_kernel, nb_parallel_indices, loopdim} + if is_parallel_kernel + range_expr = :(ParallelStencil.get_ranges_memopt(nthreads_x_max, nthreads_max_memopt, $loopdim, args...)) + else + range_expr = :(ParallelStencil.ParallelKernel.get_ranges(args...)) + end + errorcall = :(ParallelStencil.@ArgumentError(ParallelStencil.ERRMSG_AUTOMATIC_RANGES_PARALLEL)) + return quote + nb_input_dims = ParallelStencil.get_nb_input_dims(args...) + nb_dims_match = (nb_input_dims == $nb_parallel_indices) + if nb_dims_match isa Bool + nb_dims_match || $errorcall + end + $range_expr + end +end + function parallel_call_memopt(caller::Module, ranges::Union{Symbol,Expr}, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool; memopt::Bool=false, configcall::Expr=kernelcall) if haskey(backend_kwargs_expr, :shmem) @KeywordArgumentError("@parallel : keyword `shmem` is not allowed when memopt=true is set.") end package = get_package(caller) @@ -382,8 +414,12 @@ function parallel_call_memopt(caller::Module, ranges::Union{Symbol,Expr}, kernel nthreads = :( ParallelStencil.compute_nthreads_memopt($nthreads_x_max, $nthreads_max_memopt, $maxsize, $loopdim, $stencilranges) ) nblocks = :( ParallelStencil.ParallelKernel.compute_nblocks($maxsize, $nthreads) ) numbertype = get_numbertype(caller) # not :(eltype($(optvars)[1])) # TODO: see how to obtain number type properly for each array: the type of the call call arguments corresponding to the optimization variables should be checked - A = gensym("A") - shmem = :($use_any_shmem ? sum(($nthreads[$shmem_dim1] + $use_shmemhalos[$A] * ($(shmem_spans)[$A][1])) * ($nthreads[$shmem_dim2] + $use_shmemhalos[$A] * ($(shmem_spans)[$A][2])) * sizeof($numbertype) for $A in $shmem_optvars) : 0) + if get_nonconst_metadata(caller) + A = gensym("A") + shmem = :($use_any_shmem ? sum(($nthreads[$shmem_dim1] + $use_shmemhalos[$A] * ($(shmem_spans)[$A][1])) * ($nthreads[$shmem_dim2] + $use_shmemhalos[$A] * ($(shmem_spans)[$A][2])) * sizeof($numbertype) for $A in $shmem_optvars) : 0) + else + shmem = :(ParallelStencil.compute_memopt_shmem(Val($shmem_optvars), Val($use_shmemhalos), Val($shmem_spans), Val($shmem_dim1), Val($shmem_dim2), $nthreads, $numbertype)) + end if (async) return :(@parallel_async memopt=false $configcall_kwarg_expr $ranges $nblocks $nthreads shmem=$shmem $(backend_kwargs_expr...) $kernelcall) #TODO: the package and numbertype will have to be passed here further once supported as kwargs else return :(@parallel memopt=false $configcall_kwarg_expr $ranges $nblocks $nthreads shmem=$shmem $(backend_kwargs_expr...) $kernelcall) #TODO: ... end @@ -397,7 +433,12 @@ function parallel_call_memopt(caller::Module, kernelcall::Expr, backend_kwargs_e metadata_module = metadata_call loopdim = :($(metadata_module).loopdim) is_parallel_kernel = :($(metadata_module).is_parallel_kernel) - ranges = add_nb_parallel_indices_check(:( ($is_parallel_kernel) ? ParallelStencil.get_ranges_memopt($nthreads_x_max, $nthreads_max_memopt, $loopdim, $(configcall.args[2:end]...)) : ParallelStencil.ParallelKernel.get_ranges($(configcall.args[2:end]...))), configcall) + if get_nonconst_metadata(caller) + ranges = add_nb_parallel_indices_check(:( ($is_parallel_kernel) ? ParallelStencil.get_ranges_memopt($nthreads_x_max, $nthreads_max_memopt, $loopdim, $(configcall.args[2:end]...)) : ParallelStencil.ParallelKernel.get_ranges($(configcall.args[2:end]...))), configcall) + else + nb_parallel_indices = :($(metadata_module).nb_parallel_indices) + ranges = :(ParallelStencil.compute_memopt_ranges(Val($is_parallel_kernel), Val($nb_parallel_indices), Val($loopdim), $nthreads_x_max, $nthreads_max_memopt, $(configcall.args[2:end]...))) + end parallel_call_memopt(caller, ranges, kernelcall, backend_kwargs_expr, async; memopt=memopt, configcall=configcall) end @@ -553,7 +594,10 @@ function create_metadata_function(kernel::Expr, metadata_module::Module) # NOTE: kernelname = get_name(kernel) functionname = get_meta_function(kernelname) metadata_function = set_name(metadata_function, functionname) - set_body!(metadata_function, :(return $metadata_module)) + set_body!(metadata_function, quote + return $metadata_module + end) + Base.pushmeta!(metadata_function, :inline) return metadata_function end From c06e3269a3dbecc9d35ad4ff8cbd854925f2f71b Mon Sep 17 00:00:00 2001 From: Samuel Omlin Date: Mon, 16 Mar 2026 15:19:02 +0100 Subject: [PATCH 08/14] Add shared memory halo usage to store_metadata function --- src/memopt.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/memopt.jl b/src/memopt.jl index 395ffcb..dacc336 100644 --- a/src/memopt.jl +++ b/src/memopt.jl @@ -1024,6 +1024,7 @@ function store_metadata(metadata_module::Module, is_parallel_kernel::Bool, calle memopt = true nonconst_metadata = get_nonconst_metadata(caller) stencilranges = NamedTuple(A => (offset_mins[A][1]:offset_maxs[A][1], offset_mins[A][2]:offset_maxs[A][2], offset_mins[A][3]:offset_maxs[A][3]) for A in optvars) + use_shmemhalos = NamedTuple(A => use_shmemhalos[A] for A in optvars) loopsizes = (loopdim==3) ? (1, 1, loopsize) : (loopdim==2) ? (1, loopsize, 1) : (loopsize, 1, 1) shmem_dim1 = (loopdim==3) ? 1 : (loopdim==2) ? 1 : 2 shmem_dim2 = (loopdim==3) ? 2 : (loopdim==2) ? 3 : 3 From 9aa318708f89834098eb8c3a3b3e9780c4e83151 Mon Sep 17 00:00:00 2001 From: Samuel Omlin Date: Mon, 16 Mar 2026 16:57:11 +0100 Subject: [PATCH 09/14] Refactor create_metadata_function to use macro for inline return --- src/parallel.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/parallel.jl b/src/parallel.jl index 5983695..c8397dd 100644 --- a/src/parallel.jl +++ b/src/parallel.jl @@ -597,8 +597,7 @@ function create_metadata_function(kernel::Expr, metadata_module::Module) # NOTE: set_body!(metadata_function, quote return $metadata_module end) - Base.pushmeta!(metadata_function, :inline) - return metadata_function + return :(@inline $metadata_function) end function create_metadata_call(configcall::Expr) From 5b7335c24e2ab6f0b41460b8fb899e1728b698ae Mon Sep 17 00:00:00 2001 From: Samuel Omlin Date: Mon, 16 Mar 2026 16:57:23 +0100 Subject: [PATCH 10/14] Add nonconst_metadata option to @init_parallel_stencil in test_FiniteDifferences1D --- test/test_FiniteDifferences1D.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_FiniteDifferences1D.jl b/test/test_FiniteDifferences1D.jl index 4e1cd5a..2c95dca 100644 --- a/test/test_FiniteDifferences1D.jl +++ b/test/test_FiniteDifferences1D.jl @@ -36,7 +36,7 @@ eval(:( $(interpolate(:__padding__, (false, package!=PKG_POLYESTER), :( #TODO: this needs to be restored to (false, true) when Polyester supports padding. @testset "(padding=$__padding__)" begin @require !@is_initialized() - @init_parallel_stencil($package, $FloatDefault, 1, padding=__padding__) + @init_parallel_stencil($package, $FloatDefault, 1, padding=__padding__, nonconst_metadata=true) @require @is_initialized() nx = (9,) A = @IField(nx, @rand); From 9484710e7f079da1bc822ea61d3d50b9d19fad25 Mon Sep 17 00:00:00 2001 From: Samuel Omlin Date: Mon, 16 Mar 2026 16:57:31 +0100 Subject: [PATCH 11/14] Add nonconst_metadata option to @init_parallel_stencil in test_FiniteDifferences2D --- test/test_FiniteDifferences2D.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_FiniteDifferences2D.jl b/test/test_FiniteDifferences2D.jl index 17cf848..6b0a6c5 100644 --- a/test/test_FiniteDifferences2D.jl +++ b/test/test_FiniteDifferences2D.jl @@ -36,7 +36,7 @@ eval(:( $(interpolate(:__padding__, (false, package!=PKG_POLYESTER), :( #TODO: this needs to be restored to (false, true) when Polyester supports padding. @testset "(padding=$__padding__)" begin @require !@is_initialized() - @init_parallel_stencil($package, $FloatDefault, 2, padding=__padding__) + @init_parallel_stencil($package, $FloatDefault, 2, padding=__padding__, nonconst_metadata=true) @require @is_initialized() nxy = (9, 7) A = @IField(nxy, @rand); From 1484241a15626253e8653734ebddc43286283e0d Mon Sep 17 00:00:00 2001 From: Samuel Omlin Date: Mon, 16 Mar 2026 16:57:54 +0100 Subject: [PATCH 12/14] Add nonconst_metadata option to @init_parallel_stencil in test_kernel_language --- test/test_kernel_language.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_kernel_language.jl b/test/test_kernel_language.jl index 3bae9a9..72e0701 100644 --- a/test/test_kernel_language.jl +++ b/test/test_kernel_language.jl @@ -38,7 +38,7 @@ Base.retry_load_extensions() eval(:( @testset "$(basename(@__FILE__)) (package: $(nameof($package)))" begin @require !@is_initialized() - @init_parallel_stencil($package, $FloatDefault, 3) + @init_parallel_stencil($package, $FloatDefault, 3, nonconst_metadata=true) @require @is_initialized() @testset "Pass-through macro mapping" begin From 700765402ca44bbb3cf01caa3a8ba7db90dbd71e Mon Sep 17 00:00:00 2001 From: Samuel Omlin Date: Mon, 16 Mar 2026 16:58:02 +0100 Subject: [PATCH 13/14] Add nonconst_metadata option to @init_parallel_stencil in test_FiniteDifferences3D --- test/test_FiniteDifferences3D.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_FiniteDifferences3D.jl b/test/test_FiniteDifferences3D.jl index eb76dc0..88013cf 100644 --- a/test/test_FiniteDifferences3D.jl +++ b/test/test_FiniteDifferences3D.jl @@ -36,7 +36,7 @@ eval(:( $(interpolate(:__padding__, (false, package!=PKG_POLYESTER), :( #TODO: this needs to be restored to (false, true) when Polyester supports padding. @testset "(padding=$__padding__)" begin @require !@is_initialized() - @init_parallel_stencil($package, $FloatDefault, 3, padding=__padding__) + @init_parallel_stencil($package, $FloatDefault, 3, padding=__padding__, nonconst_metadata=true) @require @is_initialized() nxyz = (9, 7, 8) A = @IField(nxyz, @rand) From 8b9f4fddc9050daf75ead61494f03c917471a450 Mon Sep 17 00:00:00 2001 From: Samuel Omlin Date: Mon, 16 Mar 2026 16:58:08 +0100 Subject: [PATCH 14/14] Add nonconst_metadata option to @init_parallel_stencil in test_parallel --- test/test_parallel.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/test_parallel.jl b/test/test_parallel.jl index dc555e3..f2b278a 100644 --- a/test/test_parallel.jl +++ b/test/test_parallel.jl @@ -1011,7 +1011,7 @@ eval(:( @testset "3. parallel (with Fields)" begin @static if $package != $PKG_POLYESTER # TODO: this needs to be removed once Polyester supports padding @require !@is_initialized() - @init_parallel_stencil($package, $FloatDefault, 3, padding=true) + @init_parallel_stencil($package, $FloatDefault, 3, padding=true, nonconst_metadata=true) @require @is_initialized() @testset "padding" begin @testset "@parallel (3D, @all)" begin @@ -1063,7 +1063,7 @@ eval(:( @testset "4. global defaults" begin @testset "inbounds=true" begin @require !@is_initialized() - @init_parallel_stencil($package, $FloatDefault, 1, inbounds=true) + @init_parallel_stencil($package, $FloatDefault, 1, inbounds=true, nonconst_metadata=true) @require @is_initialized expansion = @prettystring(1, @parallel_indices (ix) inbounds=true f(A) = (2*A; return)) @test occursin("Base.@inbounds begin", expansion) @@ -1076,7 +1076,7 @@ eval(:( @testset "padding=true" begin @static if $package != $PKG_POLYESTER # TODO: this needs to be removed once Polyester supports padding @require !@is_initialized() - @init_parallel_stencil($package, $FloatDefault, 3, padding=true) + @init_parallel_stencil($package, $FloatDefault, 3, padding=true, nonconst_metadata=true) @require @is_initialized @testset "apply masks | handling padding (padding=true (globally))" begin expansion = @prettystring(1, @parallel sum!(A, B) = (@all(A) = @all(A) + @all(B); return)) @@ -1097,7 +1097,7 @@ eval(:( end; @testset "@parallel_indices (I...) (1D)" begin @require !@is_initialized() - @init_parallel_stencil($package, $FloatDefault, 1) + @init_parallel_stencil($package, $FloatDefault, 1, nonconst_metadata=true) @require @is_initialized A = @zeros(4*5*6) one = $FloatDefault(1) @@ -1111,7 +1111,7 @@ eval(:( end; @testset "@parallel_indices (I...) (2D)" begin @require !@is_initialized() - @init_parallel_stencil($package, $FloatDefault, 2) + @init_parallel_stencil($package, $FloatDefault, 2, nonconst_metadata=true) @require @is_initialized A = @zeros(4, 5*6) one = $FloatDefault(1) @@ -1125,7 +1125,7 @@ eval(:( end; @testset "@parallel_indices (I...) (3D)" begin @require !@is_initialized() - @init_parallel_stencil($package, $FloatDefault, 3) + @init_parallel_stencil($package, $FloatDefault, 3, nonconst_metadata=true) @require @is_initialized A = @zeros(4, 5, 6) one = $FloatDefault(1) @@ -1140,7 +1140,7 @@ eval(:( end; @testset "5. parallel macros (numbertype and ndims ommited)" begin @require !@is_initialized() - @init_parallel_stencil(package = $package) + @init_parallel_stencil(package = $package, nonconst_metadata=true) @require @is_initialized $(interpolate(:__T__, ARRAYTYPES, :( @testset "Data.__T__{T} to Data.Device.__T__{T}" begin