Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions src/memopt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ function memopt(metadata_module::Module, is_parallel_kernel::Bool, caller::Modul
loopstart = minimum(values(loopentrys))
loopend = loopsize
use_any_shmem = any(values(use_shmems))
shmem_optvars = tuple((A for A in optvars if use_shmems[A])...)::Tuple{Vararg{Symbol}}
shmem_index_groups = define_shmem_index_groups(hx1s, hy1s, hx2s, hy2s, optvars, use_shmems, loopdim)
shmem_vars = define_shmem_vars(oz_maxs, hx1s, hy1s, hx2s, hy2s, optvars, indices, use_shmems, use_shmem_xs, use_shmem_ys, shmem_index_groups, use_shmemhalos, use_shmemindices, loopdim)
shmem_exprs = define_shmem_exprs(shmem_vars, loopdim)
Expand Down Expand Up @@ -469,7 +470,7 @@ $(( # NOTE: the if statement is not needed here as we only deal with registers
else
@ArgumentError("memopt: only loopdim=3 is currently supported.")
end
store_metadata(metadata_module, is_parallel_kernel, caller, offset_mins, offset_maxs, offsets, optvars, loopdim, loopsize, optranges, use_shmemhalos)
store_metadata(metadata_module, is_parallel_kernel, caller, offset_mins, offset_maxs, offsets, optvars, shmem_optvars, use_any_shmem, loopdim, loopsize, optranges, use_shmemhalos)
# @show QuoteNode(ParallelKernel.simplify_varnames!(ParallelKernel.remove_linenumbernodes!(deepcopy(body))))
return body
end
Expand Down Expand Up @@ -1019,10 +1020,15 @@ function wrap_loop(index::Symbol, range::UnitRange, block::Expr; unroll=false)
end
end

function store_metadata(metadata_module::Module, is_parallel_kernel::Bool, caller::Module, offset_mins::Dict{Symbol, <:NTuple{3,Integer}}, offset_maxs::Dict{Symbol, <:NTuple{3,Integer}}, offsets::Dict{Symbol, Dict{Any, Any}}, optvars::NTuple{N,Symbol} where N, loopdim::Integer, loopsize::Integer, optranges::Dict{Any, Any}, use_shmemhalos)
function store_metadata(metadata_module::Module, is_parallel_kernel::Bool, caller::Module, offset_mins::Dict{Symbol, <:NTuple{3,Integer}}, offset_maxs::Dict{Symbol, <:NTuple{3,Integer}}, offsets::Dict{Symbol, Dict{Any, Any}}, optvars::NTuple{N,Symbol} where N, shmem_optvars::NTuple{M,Symbol} where M, use_any_shmem::Bool, loopdim::Integer, loopsize::Integer, optranges::Dict{Any, Any}, use_shmemhalos)
memopt = true
nonconst_metadata = get_nonconst_metadata(caller)
stencilranges = NamedTuple(A => (offset_mins[A][1]:offset_maxs[A][1], offset_mins[A][2]:offset_maxs[A][2], offset_mins[A][3]:offset_maxs[A][3]) for A in optvars)
use_shmemhalos = NamedTuple(A => use_shmemhalos[A] for A in optvars)
loopsizes = (loopdim==3) ? (1, 1, loopsize) : (loopdim==2) ? (1, loopsize, 1) : (loopsize, 1, 1)
shmem_dim1 = (loopdim==3) ? 1 : (loopdim==2) ? 1 : 2
shmem_dim2 = (loopdim==3) ? 2 : (loopdim==2) ? 3 : 3
shmem_spans = NamedTuple(A => (length(stencilranges[A][shmem_dim1]) - 1, length(stencilranges[A][shmem_dim2]) - 1) for A in optvars)
if nonconst_metadata
storeexpr = quote
is_parallel_kernel = $is_parallel_kernel
Expand All @@ -1031,9 +1037,15 @@ function store_metadata(metadata_module::Module, is_parallel_kernel::Bool, calle
stencilranges = $stencilranges
offsets = $offsets
optvars = $optvars
shmem_optvars = $shmem_optvars
shmem_spans = $shmem_spans
loopdim = $loopdim
loopsize = $loopsize
loopsizes = $loopsizes
shmem_dim1 = $shmem_dim1
shmem_dim2 = $shmem_dim2
optranges = $optranges
use_any_shmem = $use_any_shmem
use_shmemhalos = $use_shmemhalos
end
else
Expand All @@ -1044,9 +1056,15 @@ function store_metadata(metadata_module::Module, is_parallel_kernel::Bool, calle
const stencilranges = $stencilranges
const offsets = $offsets
const optvars = $optvars
const shmem_optvars = $shmem_optvars
const shmem_spans = $shmem_spans
const loopdim = $loopdim
const loopsize = $loopsize
const loopsizes = $loopsizes
const shmem_dim1 = $shmem_dim1
const shmem_dim2 = $shmem_dim2
const optranges = $optranges
const use_any_shmem = $use_any_shmem
const use_shmemhalos = $use_shmemhalos
end
end
Expand Down
66 changes: 55 additions & 11 deletions src/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,38 @@ end

## @PARALLEL CALL FUNCTIONS

"""
    compute_memopt_shmem(Val(optvars), Val(use_shmemhalos), Val(shmem_spans), Val(shmem_dim1), Val(shmem_dim2), nthreads, T)

Return the sum, over every variable `A` in `optvars`, of

    (nthreads[shmem_dim1] + use_shmemhalos.A * shmem_spans.A[1]) *
    (nthreads[shmem_dim2] + use_shmemhalos.A * shmem_spans.A[2]) * sizeof(T)

All `Val` parameters are read at code-generation time, so the emitted method body
only contains the arithmetic on `nthreads` and `sizeof(T)` with the per-variable
halo flags and spans inlined as literals.

# Arguments
- `optvars`: tuple of variable names (properties looked up in the two named tuples below).
- `use_shmemhalos`: object with one property per entry of `optvars`; used as a multiplicative factor.
- `shmem_spans`: object with one property per entry of `optvars`; elements `[1]` and `[2]` are read.
- `shmem_dim1`, `shmem_dim2`: indices into `nthreads`.
- `nthreads`: indexable collection of thread counts (runtime value).
- `T`: type whose `sizeof` scales every term.

# Returns
`0` when `optvars` is empty, otherwise the sum described above.
"""
@generated function compute_memopt_shmem(::Val{optvars}, ::Val{use_shmemhalos}, ::Val{shmem_spans}, ::Val{shmem_dim1}, ::Val{shmem_dim2}, nthreads, ::Type{T}) where {optvars, use_shmemhalos, shmem_spans, shmem_dim1, shmem_dim2, T}
    # Nothing to account for: emit a literal zero.
    isempty(optvars) && return :(0)

    # Build one product expression per variable; halo flag and spans are spliced in as constants.
    summands = Expr[]
    for var in optvars
        halo = getproperty(use_shmemhalos, var)
        spans = getproperty(shmem_spans, var)
        extent1 = :(nthreads[$shmem_dim1] + $halo * $(spans[1]))
        extent2 = :(nthreads[$shmem_dim2] + $halo * $(spans[2]))
        push!(summands, :($extent1 * $extent2 * sizeof(T)))
    end

    # A single summand is returned as is; several are combined in one n-ary `+` call.
    return length(summands) == 1 ? first(summands) : Expr(:call, :+, summands...)
end

"""
    compute_memopt_ranges(Val(is_parallel_kernel), Val(nb_parallel_indices), Val(loopdim), nthreads_x_max, nthreads_max_memopt, args...)

Generate code that first compares `ParallelStencil.get_nb_input_dims(args...)` with
`nb_parallel_indices` and then evaluates to the ranges for `args`.

If the comparison yields a `Bool` that is `false`, the generated code invokes
`ParallelStencil.@ArgumentError` with `ParallelStencil.ERRMSG_AUTOMATIC_RANGES_PARALLEL`;
a non-`Bool` comparison result skips the check.

The ranges call is chosen at code-generation time from `is_parallel_kernel`:
- `true`:  `ParallelStencil.get_ranges_memopt(nthreads_x_max, nthreads_max_memopt, loopdim, args...)`
- `false`: `ParallelStencil.ParallelKernel.get_ranges(args...)`
"""
@generated function compute_memopt_ranges(::Val{is_parallel_kernel}, ::Val{nb_parallel_indices}, ::Val{loopdim}, nthreads_x_max, nthreads_max_memopt, args...) where {is_parallel_kernel, nb_parallel_indices, loopdim}
    # Only one of the two range computations ends up in the generated method body.
    ranges_call = is_parallel_kernel ?
        :(ParallelStencil.get_ranges_memopt(nthreads_x_max, nthreads_max_memopt, $loopdim, args...)) :
        :(ParallelStencil.ParallelKernel.get_ranges(args...))
    dims_error = :(ParallelStencil.@ArgumentError(ParallelStencil.ERRMSG_AUTOMATIC_RANGES_PARALLEL))
    return quote
        dims_ok = (ParallelStencil.get_nb_input_dims(args...) == $nb_parallel_indices)
        # The check only applies when the comparison produced a plain Bool.
        if dims_ok isa Bool && !dims_ok
            $dims_error
        end
        $ranges_call
    end
end

function parallel_call_memopt(caller::Module, ranges::Union{Symbol,Expr}, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool; memopt::Bool=false, configcall::Expr=kernelcall)
if haskey(backend_kwargs_expr, :shmem) @KeywordArgumentError("@parallel <kernelcall>: keyword `shmem` is not allowed when memopt=true is set.") end
package = get_package(caller)
Expand All @@ -369,20 +401,25 @@ function parallel_call_memopt(caller::Module, ranges::Union{Symbol,Expr}, kernel
configcall_kwarg_expr = :(configcall=$configcall)
metadata_call = create_metadata_call(configcall)
metadata_module = metadata_call
loopdim = :($(metadata_module).loopdim)
loopsizes = :($(metadata_module).loopsizes)
stencilranges = :($(metadata_module).stencilranges)
use_shmemhalos = :($(metadata_module).use_shmemhalos)
optvars = :($(metadata_module).optvars)
loopdim = :($(metadata_module).loopdim)
loopsize = :($(metadata_module).loopsize)
loopsizes = :(($loopdim==3) ? (1, 1, $loopsize) : ($loopdim==2) ? (1, $loopsize, 1) : ($loopsize, 1, 1))
use_any_shmem = :($(metadata_module).use_any_shmem)
shmem_dim1 = :($(metadata_module).shmem_dim1)
shmem_dim2 = :($(metadata_module).shmem_dim2)
shmem_optvars = :($(metadata_module).shmem_optvars)
shmem_spans = :($(metadata_module).shmem_spans)
maxsize = :(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges($ranges)), $loopsizes))
nthreads = :( ParallelStencil.compute_nthreads_memopt($nthreads_x_max, $nthreads_max_memopt, $maxsize, $loopdim, $stencilranges) )
nblocks = :( ParallelStencil.ParallelKernel.compute_nblocks($maxsize, $nthreads) )
numbertype = get_numbertype(caller) # not :(eltype($(optvars)[1])) # TODO: see how to obtain number type properly for each array: the type of the call call arguments corresponding to the optimization variables should be checked
dim1 = :(($loopdim==3) ? 1 : ($loopdim==2) ? 1 : 2) # TODO: to be determined if that is what is desired for loopdim 1 and 2.
dim2 = :(($loopdim==3) ? 2 : ($loopdim==2) ? 3 : 3) # TODO: to be determined if that is what is desired for loopdim 1 and 2.
A = gensym("A")
shmem = :(sum(($nthreads[$dim1]+$use_shmemhalos[$A]*(length($(stencilranges)[$A][$dim1])-1))*($nthreads[$dim2]+$use_shmemhalos[$A]*(length($(stencilranges)[$A][$dim2])-1))*sizeof($numbertype) for $A in $optvars))
if get_nonconst_metadata(caller)
A = gensym("A")
shmem = :($use_any_shmem ? sum(($nthreads[$shmem_dim1] + $use_shmemhalos[$A] * ($(shmem_spans)[$A][1])) * ($nthreads[$shmem_dim2] + $use_shmemhalos[$A] * ($(shmem_spans)[$A][2])) * sizeof($numbertype) for $A in $shmem_optvars) : 0)
else
shmem = :(ParallelStencil.compute_memopt_shmem(Val($shmem_optvars), Val($use_shmemhalos), Val($shmem_spans), Val($shmem_dim1), Val($shmem_dim2), $nthreads, $numbertype))
end
if (async) return :(@parallel_async memopt=false $configcall_kwarg_expr $ranges $nblocks $nthreads shmem=$shmem $(backend_kwargs_expr...) $kernelcall) #TODO: the package and numbertype will have to be passed here further once supported as kwargs
else return :(@parallel memopt=false $configcall_kwarg_expr $ranges $nblocks $nthreads shmem=$shmem $(backend_kwargs_expr...) $kernelcall) #TODO: ...
end
Expand All @@ -396,7 +433,12 @@ function parallel_call_memopt(caller::Module, kernelcall::Expr, backend_kwargs_e
metadata_module = metadata_call
loopdim = :($(metadata_module).loopdim)
is_parallel_kernel = :($(metadata_module).is_parallel_kernel)
ranges = add_nb_parallel_indices_check(:( ($is_parallel_kernel) ? ParallelStencil.get_ranges_memopt($nthreads_x_max, $nthreads_max_memopt, $loopdim, $(configcall.args[2:end]...)) : ParallelStencil.ParallelKernel.get_ranges($(configcall.args[2:end]...))), configcall)
if get_nonconst_metadata(caller)
ranges = add_nb_parallel_indices_check(:( ($is_parallel_kernel) ? ParallelStencil.get_ranges_memopt($nthreads_x_max, $nthreads_max_memopt, $loopdim, $(configcall.args[2:end]...)) : ParallelStencil.ParallelKernel.get_ranges($(configcall.args[2:end]...))), configcall)
else
nb_parallel_indices = :($(metadata_module).nb_parallel_indices)
ranges = :(ParallelStencil.compute_memopt_ranges(Val($is_parallel_kernel), Val($nb_parallel_indices), Val($loopdim), $nthreads_x_max, $nthreads_max_memopt, $(configcall.args[2:end]...)))
end
parallel_call_memopt(caller, ranges, kernelcall, backend_kwargs_expr, async; memopt=memopt, configcall=configcall)
end

Expand Down Expand Up @@ -552,8 +594,10 @@ function create_metadata_function(kernel::Expr, metadata_module::Module) # NOTE:
kernelname = get_name(kernel)
functionname = get_meta_function(kernelname)
metadata_function = set_name(metadata_function, functionname)
set_body!(metadata_function, :(return $metadata_module))
return metadata_function
set_body!(metadata_function, quote
return $metadata_module
end)
return :(@inline $metadata_function)
end

function create_metadata_call(configcall::Expr)
Expand Down
2 changes: 1 addition & 1 deletion test/test_FiniteDifferences1D.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ eval(:(
$(interpolate(:__padding__, (false, package!=PKG_POLYESTER), :( #TODO: this needs to be restored to (false, true) when Polyester supports padding.
@testset "(padding=$__padding__)" begin
@require !@is_initialized()
@init_parallel_stencil($package, $FloatDefault, 1, padding=__padding__)
@init_parallel_stencil($package, $FloatDefault, 1, padding=__padding__, nonconst_metadata=true)
@require @is_initialized()
nx = (9,)
A = @IField(nx, @rand);
Expand Down
2 changes: 1 addition & 1 deletion test/test_FiniteDifferences2D.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ eval(:(
$(interpolate(:__padding__, (false, package!=PKG_POLYESTER), :( #TODO: this needs to be restored to (false, true) when Polyester supports padding.
@testset "(padding=$__padding__)" begin
@require !@is_initialized()
@init_parallel_stencil($package, $FloatDefault, 2, padding=__padding__)
@init_parallel_stencil($package, $FloatDefault, 2, padding=__padding__, nonconst_metadata=true)
@require @is_initialized()
nxy = (9, 7)
A = @IField(nxy, @rand);
Expand Down
2 changes: 1 addition & 1 deletion test/test_FiniteDifferences3D.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ eval(:(
$(interpolate(:__padding__, (false, package!=PKG_POLYESTER), :( #TODO: this needs to be restored to (false, true) when Polyester supports padding.
@testset "(padding=$__padding__)" begin
@require !@is_initialized()
@init_parallel_stencil($package, $FloatDefault, 3, padding=__padding__)
@init_parallel_stencil($package, $FloatDefault, 3, padding=__padding__, nonconst_metadata=true)
@require @is_initialized()
nxyz = (9, 7, 8)
A = @IField(nxyz, @rand)
Expand Down
2 changes: 1 addition & 1 deletion test/test_kernel_language.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ Base.retry_load_extensions()
eval(:(
@testset "$(basename(@__FILE__)) (package: $(nameof($package)))" begin
@require !@is_initialized()
@init_parallel_stencil($package, $FloatDefault, 3)
@init_parallel_stencil($package, $FloatDefault, 3, nonconst_metadata=true)
@require @is_initialized()

@testset "Pass-through macro mapping" begin
Expand Down
Loading
Loading