Skip to content
2 changes: 1 addition & 1 deletion src/ParallelKernel/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Declare the `kernelcall` parallel. The kernel will automatically be called as re
Automatic computation of `ranges` for `@parallel <kernelcall>` is only possible if the number of parallel indices used by the kernel is equal to the number of dimensions of the highest-dimensional input arrays. Otherwise, specify the `ranges` manually with `@parallel ranges=... <kernelcall>`.

!!! note "Runtime hardware selection"
When KernelAbstractions is initialized, this wrapper consults [`current_hardware`](@ref) to determine the runtime hardware target. The symbol defaults to `:cpu` and can be switched to select other targets via [`select_hardware`](@ref).
When KernelAbstractions is chosen as the package for parallelization, this wrapper consults [`current_hardware`](@ref) to determine the runtime hardware target. The symbol defaults to `:cpu` and can be switched to select other targets via [`select_hardware`](@ref).

# Arguments
- `kernelcall`: a call to a kernel that is declared parallel.
Expand Down
177 changes: 108 additions & 69 deletions src/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ Declare the `kernelcall` parallel. The kernel will automatically be called as re
Automatic computation of `ranges` for `@parallel <kernelcall>` is only possible if the number of parallel indices used by the kernel is equal to the number of dimensions of the highest-dimensional input arrays. Otherwise, specify the `ranges` manually with `@parallel ranges=... <kernelcall>`.

!!! note "Runtime hardware selection"
When KernelAbstractions is initialized, this wrapper consults [`current_hardware`](@ref) to determine the runtime hardware target. The symbol defaults to `:cpu` and can be switched to select other targets via [`select_hardware`](@ref).
When KernelAbstractions is chosen as the package for parallelization, this wrapper consults [`current_hardware`](@ref) to determine the runtime hardware target. The symbol defaults to `:cpu` and can be switched to select other targets via [`select_hardware`](@ref).

# Arguments
- `kernelcall`: a call to a kernel that is declared parallel.
Expand Down Expand Up @@ -191,7 +191,7 @@ function parallel(source::LineNumberNode, caller::Module, args::Union{Symbol,Exp
parallel_call_memopt(caller, posargs..., kernelarg, backend_kwargs_expr, async; kwargs...)
else
if isempty(posargs)
ranges = add_nb_parallel_indices_check(:(ParallelStencil.ParallelKernel.get_ranges($(configcall.args[2:end]...))), configcall)
ranges = :(ParallelStencil.compute_parallel_ranges(Val(($(create_metadata_call(configcall))).nb_parallel_indices), $(configcall.args[2:end]...)))
ParallelKernel.parallel(caller, ranges, backend_kwargs_expr..., configcall_kwarg_expr, kernelarg; package=package, async=async)
else
ParallelKernel.parallel(caller, posargs..., backend_kwargs_expr..., configcall_kwarg_expr, kernelarg; package=package, async=async)
Expand Down Expand Up @@ -361,67 +361,42 @@ end

## @PARALLEL CALL FUNCTIONS

@generated function compute_memopt_shmem(::Val{optvars}, ::Val{use_shmemhalos}, ::Val{shmem_spans}, ::Val{shmem_dim1}, ::Val{shmem_dim2}, nthreads, ::Type{T}) where {optvars, use_shmemhalos, shmem_spans, shmem_dim1, shmem_dim2, T}
terms = [:(
(nthreads[$shmem_dim1] + $(getproperty(use_shmemhalos, A)) * $(getproperty(shmem_spans, A)[1])) *
(nthreads[$shmem_dim2] + $(getproperty(use_shmemhalos, A)) * $(getproperty(shmem_spans, A)[2])) *
sizeof(T)
) for A in optvars]
if isempty(terms)
return :(0)
elseif length(terms) == 1
return terms[1]
else
return Expr(:call, :+, terms...)
end
end

@generated function compute_memopt_ranges(::Val{is_parallel_kernel}, ::Val{nb_parallel_indices}, ::Val{loopdim}, nthreads_x_max, nthreads_max_memopt, args...) where {is_parallel_kernel, nb_parallel_indices, loopdim}
if is_parallel_kernel
range_expr = :(ParallelStencil.get_ranges_memopt(nthreads_x_max, nthreads_max_memopt, $loopdim, args...))
else
range_expr = :(ParallelStencil.ParallelKernel.get_ranges(args...))
end
errorcall = :(ParallelStencil.@ArgumentError(ParallelStencil.ERRMSG_AUTOMATIC_RANGES_PARALLEL))
return quote
nb_input_dims = ParallelStencil.get_nb_input_dims(args...)
nb_dims_match = (nb_input_dims == $nb_parallel_indices)
if nb_dims_match isa Bool
nb_dims_match || $errorcall
end
$range_expr
end
end

function parallel_call_memopt(caller::Module, ranges::Union{Symbol,Expr}, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool; memopt::Bool=false, configcall::Expr=kernelcall)
function parallel_call_memopt(caller::Module, metadata_expr::Union{Symbol,Expr}, ranges::Union{Symbol,Expr}, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool; memopt::Bool=false, configcall::Expr=kernelcall)
if haskey(backend_kwargs_expr, :shmem) @KeywordArgumentError("@parallel <kernelcall>: keyword `shmem` is not allowed when memopt=true is set.") end
package = get_package(caller)
nthreads_x_max = ParallelKernel.determine_nthreads_x_max(package)
nthreads_max_memopt = determine_nthreads_max_memopt(package)
configcall_kwarg_expr = :(configcall=$configcall)
metadata_call = create_metadata_call(configcall)
metadata_module = metadata_call
loopdim = :($(metadata_module).loopdim)
loopsizes = :($(metadata_module).loopsizes)
stencilranges = :($(metadata_module).stencilranges)
use_shmemhalos = :($(metadata_module).use_shmemhalos)
use_any_shmem = :($(metadata_module).use_any_shmem)
shmem_dim1 = :($(metadata_module).shmem_dim1)
shmem_dim2 = :($(metadata_module).shmem_dim2)
shmem_optvars = :($(metadata_module).shmem_optvars)
shmem_spans = :($(metadata_module).shmem_spans)
maxsize = :(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges($ranges)), $loopsizes))
nthreads = :( ParallelStencil.compute_nthreads_memopt($nthreads_x_max, $nthreads_max_memopt, $maxsize, $loopdim, $stencilranges) )
nblocks = :( ParallelStencil.ParallelKernel.compute_nblocks($maxsize, $nthreads) )
numbertype = get_numbertype(caller) # not :(eltype($(optvars)[1])) # TODO: see how to obtain number type properly for each array: the type of the call call arguments corresponding to the optimization variables should be checked
if get_nonconst_metadata(caller)
A = gensym("A")
shmem = :($use_any_shmem ? sum(($nthreads[$shmem_dim1] + $use_shmemhalos[$A] * ($(shmem_spans)[$A][1])) * ($nthreads[$shmem_dim2] + $use_shmemhalos[$A] * ($(shmem_spans)[$A][2])) * sizeof($numbertype) for $A in $shmem_optvars) : 0)
nblocks_var = gensym("nblocks")
nthreads_var = gensym("nthreads")
shmem_var = gensym("shmem")
range_setup_exprs, range_arg = precompute_parallel_memopt_arg(ranges, "ranges")
nthreads_nblocks_expr = :(ParallelStencil.compute_memopt_nthreads_nblocks(Val($metadata_expr.loopsizes), Val($metadata_expr.loopdim), Val($metadata_expr.stencilranges), $nthreads_x_max, $nthreads_max_memopt, $range_arg))
shmem_expr = :(ParallelStencil.compute_memopt_shmem(Val($metadata_expr.shmem_optvars), Val($metadata_expr.use_shmemhalos), Val($metadata_expr.shmem_spans), Val($metadata_expr.shmem_dim1), Val($metadata_expr.shmem_dim2), $nthreads_var, $numbertype))
if async
return quote
$(range_setup_exprs...)
local $nblocks_var, $nthreads_var = $nthreads_nblocks_expr
local $shmem_var = $shmem_expr
@parallel_async memopt=false $configcall_kwarg_expr $range_arg $nblocks_var $nthreads_var shmem=$shmem_var $(backend_kwargs_expr...) $kernelcall
end
else
shmem = :(ParallelStencil.compute_memopt_shmem(Val($shmem_optvars), Val($use_shmemhalos), Val($shmem_spans), Val($shmem_dim1), Val($shmem_dim2), $nthreads, $numbertype))
return quote
$(range_setup_exprs...)
local $nblocks_var, $nthreads_var = $nthreads_nblocks_expr
local $shmem_var = $shmem_expr
@parallel memopt=false $configcall_kwarg_expr $range_arg $nblocks_var $nthreads_var shmem=$shmem_var $(backend_kwargs_expr...) $kernelcall
end
end
if (async) return :(@parallel_async memopt=false $configcall_kwarg_expr $ranges $nblocks $nthreads shmem=$shmem $(backend_kwargs_expr...) $kernelcall) #TODO: the package and numbertype will have to be passed here further once supported as kwargs
else return :(@parallel memopt=false $configcall_kwarg_expr $ranges $nblocks $nthreads shmem=$shmem $(backend_kwargs_expr...) $kernelcall) #TODO: ...
end

function parallel_call_memopt(caller::Module, ranges::Union{Symbol,Expr}, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool; memopt::Bool=false, configcall::Expr=kernelcall)
metadata_call = create_metadata_call(configcall)
metadata_var = gensym("metadata")
quote
local $metadata_var = $metadata_call
$(parallel_call_memopt(caller, metadata_var, ranges, kernelcall, backend_kwargs_expr, async; memopt=memopt, configcall=configcall))
end
end

Expand All @@ -430,16 +405,13 @@ function parallel_call_memopt(caller::Module, kernelcall::Expr, backend_kwargs_e
nthreads_x_max = ParallelKernel.determine_nthreads_x_max(package)
nthreads_max_memopt = determine_nthreads_max_memopt(package)
metadata_call = create_metadata_call(configcall)
metadata_module = metadata_call
loopdim = :($(metadata_module).loopdim)
is_parallel_kernel = :($(metadata_module).is_parallel_kernel)
if get_nonconst_metadata(caller)
ranges = add_nb_parallel_indices_check(:( ($is_parallel_kernel) ? ParallelStencil.get_ranges_memopt($nthreads_x_max, $nthreads_max_memopt, $loopdim, $(configcall.args[2:end]...)) : ParallelStencil.ParallelKernel.get_ranges($(configcall.args[2:end]...))), configcall)
else
nb_parallel_indices = :($(metadata_module).nb_parallel_indices)
ranges = :(ParallelStencil.compute_memopt_ranges(Val($is_parallel_kernel), Val($nb_parallel_indices), Val($loopdim), $nthreads_x_max, $nthreads_max_memopt, $(configcall.args[2:end]...)))
metadata_var = gensym("metadata")
ranges_var = gensym("ranges")
quote
local $metadata_var = $metadata_call
local $ranges_var = ParallelStencil.compute_memopt_ranges(Val($metadata_var.is_parallel_kernel), Val($metadata_var.nb_parallel_indices), Val($metadata_var.loopdim), $nthreads_x_max, $nthreads_max_memopt, $(configcall.args[2:end]...))
$(parallel_call_memopt(caller, metadata_var, ranges_var, kernelcall, backend_kwargs_expr, async; memopt=memopt, configcall=configcall))
end
parallel_call_memopt(caller, ranges, kernelcall, backend_kwargs_expr, async; memopt=memopt, configcall=configcall)
end


Expand Down Expand Up @@ -469,7 +441,7 @@ function compute_loopsize(package::Symbol)
end


## FUNCTIONS TO COMPUTE NTHREADS, NBLOCKS
## FUNCTIONS TO COMPUTE NTHREADS, NBLOCKS, SHARED MEMORY SIZE AND RANGES

function compute_nthreads_memopt(nthreads_x_max, nthreads_max_memopt, maxsize, loopdim, stencilranges) # This is a heuristic, which results typcially in (32,4,1) threads for a 3-D case.
maxsize = promote_maxsize(maxsize)
Expand Down Expand Up @@ -497,6 +469,76 @@ function get_ranges_memopt(nthreads_x_max, nthreads_max_memopt, loopdim, args...
end


@generated function compute_memopt_shmem(::Val{optvars}, ::Val{use_shmemhalos}, ::Val{shmem_spans}, ::Val{shmem_dim1}, ::Val{shmem_dim2}, nthreads, ::Type{T}) where {optvars, use_shmemhalos, shmem_spans, shmem_dim1, shmem_dim2, T}
terms = [:(
(nthreads[$shmem_dim1] + $(getproperty(use_shmemhalos, A)) * $(getproperty(shmem_spans, A)[1])) *
(nthreads[$shmem_dim2] + $(getproperty(use_shmemhalos, A)) * $(getproperty(shmem_spans, A)[2])) *
sizeof(T)
) for A in optvars]
if isempty(terms)
return :(0)
elseif length(terms) == 1
return terms[1]
else
return Expr(:call, :+, terms...)
end
end

@generated function compute_memopt_nthreads_nblocks(::Val{loopsizes}, ::Val{loopdim}, ::Val{stencilranges}, nthreads_x_max, nthreads_max_memopt, ranges) where {loopsizes, loopdim, stencilranges}
return quote
maxsize = cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)), $loopsizes)
nthreads = ParallelStencil.compute_nthreads_memopt(nthreads_x_max, nthreads_max_memopt, maxsize, $loopdim, $stencilranges)
nblocks = ParallelStencil.ParallelKernel.compute_nblocks(maxsize, nthreads)
(nblocks, nthreads)
end
end

@generated function check_nb_parallel_indices(::Val{nb_parallel_indices}, args...) where {nb_parallel_indices}
errorcall = :(ParallelStencil.@ArgumentError(ParallelStencil.ERRMSG_AUTOMATIC_RANGES_PARALLEL))
return quote
nb_input_dims = ParallelStencil.get_nb_input_dims(args...)
nb_dims_match = (nb_input_dims == $nb_parallel_indices)
if nb_dims_match isa Bool
nb_dims_match || $errorcall
end
nothing
end
end

@generated function compute_parallel_ranges(::Val{nb_parallel_indices}, args...) where {nb_parallel_indices}
return quote
ParallelStencil.check_nb_parallel_indices(Val($nb_parallel_indices), args...)
ParallelStencil.ParallelKernel.get_ranges(args...)
end
end

@generated function compute_memopt_ranges(::Val{is_parallel_kernel}, ::Val{nb_parallel_indices}, ::Val{loopdim}, nthreads_x_max, nthreads_max_memopt, args...) where {is_parallel_kernel, nb_parallel_indices, loopdim}
if is_parallel_kernel
range_expr = :(ParallelStencil.get_ranges_memopt(nthreads_x_max, nthreads_max_memopt, $loopdim, args...))
else
range_expr = :(ParallelStencil.ParallelKernel.get_ranges(args...))
end
errorcall = :(ParallelStencil.@ArgumentError(ParallelStencil.ERRMSG_AUTOMATIC_RANGES_PARALLEL))
return quote
nb_input_dims = ParallelStencil.get_nb_input_dims(args...)
nb_dims_match = (nb_input_dims == $nb_parallel_indices)
if nb_dims_match isa Bool
nb_dims_match || $errorcall
end
$range_expr
end
end

function precompute_parallel_memopt_arg(arg::Union{Symbol,Expr}, prefix::AbstractString)
if isa(arg, Symbol)
return Expr[], arg
else
arg_var = gensym(prefix)
return [:(local $arg_var = $arg)], arg_var
end
end


## FUNCTIONS TO DEAL WITH MASKS (@WITHIN) AND INDICES

is_splatarg(x) = isa(x,Expr) && (x.head == :...)
Expand Down Expand Up @@ -705,10 +747,7 @@ end

function add_nb_parallel_indices_check(ranges::Union{Symbol,Expr}, configcall::Expr)
metadata_call = create_metadata_call(configcall)
nb_parallel_indices = :(($metadata_call).nb_parallel_indices)
nb_input_dims = :(ParallelStencil.get_nb_input_dims($(configcall.args[2:end]...)))
errorcall = :(ParallelStencil.@ArgumentError(ParallelStencil.ERRMSG_AUTOMATIC_RANGES_PARALLEL))
return :(($nb_input_dims != $nb_parallel_indices && $errorcall; $ranges))
return :(ParallelStencil.check_nb_parallel_indices(Val(($metadata_call).nb_parallel_indices), $(configcall.args[2:end]...)); $ranges)
end

get_nb_input_dims(args...) = maximum((get_nb_input_dims(arg) for arg in args); init=1)
Expand Down
Loading
Loading