Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/ParallelKernel/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ const PARALLEL_DOC = """
@parallel (...) configcall=... backendkwargs... kernelcall
@parallel ∇=... ad_mode=... ad_annotations=... (...) backendkwargs... kernelcall

Declare the `kernelcall` parallel. The kernel will automatically be called as required by the package for parallelization selected with [`@init_parallel_kernel`](@ref). Synchronizes at the end of the call (if a stream is given via keyword arguments, then it synchronizes only this stream). The keyword argument `∇` triggers a parallel call to the gradient kernel instead of the kernel itself. The automatic differentiation is performed with the package Enzyme.jl (refer to the corresponding documentation for Enzyme-specific terms used below); Enzyme needs to be imported before ParallelKernel in order to have it load the corresponding extension.
Declare the `kernelcall` parallel. The kernel will automatically be called as required by the package for parallelization selected with [`@init_parallel_kernel`](@ref) (however, see the note below on automatic computation of `ranges`). Synchronizes at the end of the call (if a stream is given via keyword arguments, then it synchronizes only this stream). The keyword argument `∇` triggers a parallel call to the gradient kernel instead of the kernel itself. The automatic differentiation is performed with the package Enzyme.jl (refer to the corresponding documentation for Enzyme-specific terms used below); Enzyme needs to be imported before ParallelKernel in order to have it load the corresponding extension.

!!! note "Automatic computation of `ranges`"
Automatic computation of `ranges` for `@parallel <kernelcall>` is only possible if the number of parallel indices used by the kernel is equal to the number of dimensions of the highest-dimensional input arrays. Otherwise, specify the `ranges` manually with `@parallel ranges=... <kernelcall>`.

!!! note "Runtime hardware selection"
When KernelAbstractions is initialized, this wrapper consults [`current_hardware`](@ref) to determine the runtime hardware target. The symbol defaults to `:cpu` and can be switched to select other targets via [`select_hardware`](@ref).
Expand Down
83 changes: 75 additions & 8 deletions src/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ See also: [`@init_parallel_stencil`](@ref)
@parallel (...) memopt=... configcall=... backendkwargs... kernelcall
@parallel ∇=... ad_mode=... ad_annotations=... (...) memopt=... backendkwargs... kernelcall

Declare the `kernelcall` parallel. The kernel will automatically be called as required by the package for parallelization selected with [`@init_parallel_kernel`](@ref). Synchronizes at the end of the call (if a stream is given via keyword arguments, then it synchronizes only this stream). The keyword argument `∇` triggers a parallel call to the gradient kernel instead of the kernel itself. The automatic differentiation is performed with the package Enzyme.jl (refer to the corresponding documentation for Enzyme-specific terms used below); Enzyme needs to be imported before ParallelStencil in order to have it load the corresponding extension.
Declare the `kernelcall` parallel. The kernel will automatically be called as required by the package for parallelization selected with [`@init_parallel_kernel`](@ref) (however, see the note below on automatic computation of `ranges`). Synchronizes at the end of the call (if a stream is given via keyword arguments, then it synchronizes only this stream). The keyword argument `∇` triggers a parallel call to the gradient kernel instead of the kernel itself. The automatic differentiation is performed with the package Enzyme.jl (refer to the corresponding documentation for Enzyme-specific terms used below); Enzyme needs to be imported before ParallelStencil in order to have it load the corresponding extension.

!!! note "Automatic computation of `ranges`"
Automatic computation of `ranges` for `@parallel <kernelcall>` is only possible if the number of parallel indices used by the kernel is equal to the number of dimensions of the highest-dimensional input arrays. Otherwise, specify the `ranges` manually with `@parallel ranges=... <kernelcall>`.

!!! note "Runtime hardware selection"
When KernelAbstractions is initialized, this wrapper consults [`current_hardware`](@ref) to determine the runtime hardware target. The symbol defaults to `:cpu` and can be switched to select other targets via [`select_hardware`](@ref).
Expand Down Expand Up @@ -92,6 +95,9 @@ $(replace(ParallelKernel.PARALLEL_ASYNC_DOC, "@init_parallel_kernel" => "@init_p
"""
    @parallel_async(args...)

Async variant of `@parallel`: checks that the caller module is initialized and that
the arguments are valid, then expands to the asynchronous parallel kernel call.
"""
macro parallel_async(args...)
    check_initialized(__module__)
    checkargs_parallel(args...)
    return esc(parallel_async(__source__, __module__, args...))
end


# Runtime error message embedded by `add_nb_parallel_indices_check` into generated
# kernel-call expressions; raised when the automatic `ranges` computation is not
# possible for a `@parallel <kernelcall>` invocation.
const ERRMSG_AUTOMATIC_RANGES_PARALLEL = "@parallel <kernelcall>: the ranges needed for the kernel call cannot be automatically computed (less parallel indices than dimensions of the input arrays); specify the ranges manually with @parallel ranges=... <kernelcall>."


## MACROS FORCING PACKAGE, IGNORING INITIALIZATION

# Force the CUDA backend for this call, ignoring the package selected at initialization.
macro parallel_cuda(args...)
    check_initialized(__module__)
    checkargs_parallel(args...)
    return esc(parallel(__source__, __module__, args...; package=PKG_CUDA))
end
Expand Down Expand Up @@ -184,7 +190,12 @@ function parallel(source::LineNumberNode, caller::Module, args::Union{Symbol,Exp
if (length(posargs) > 1) @ArgumentError("maximum one positional argument (ranges) is allowed in a @parallel memopt=true call.") end
parallel_call_memopt(caller, posargs..., kernelarg, backend_kwargs_expr, async; kwargs...)
else
ParallelKernel.parallel(caller, posargs..., backend_kwargs_expr..., configcall_kwarg_expr, kernelarg; package=package, async=async)
if isempty(posargs)
ranges = add_nb_parallel_indices_check(:(ParallelStencil.ParallelKernel.get_ranges($(configcall.args[2:end]...))), configcall)
ParallelKernel.parallel(caller, ranges, backend_kwargs_expr..., configcall_kwarg_expr, kernelarg; package=package, async=async)
else
ParallelKernel.parallel(caller, posargs..., backend_kwargs_expr..., configcall_kwarg_expr, kernelarg; package=package, async=async)
end
end
end
end
Expand Down Expand Up @@ -213,17 +224,24 @@ function parallel_indices(source::LineNumberNode, caller::Module, args::Union{Sy
else
metadata_module, metadata_function = kwargs.metadata_module, kwargs.metadata_function
end
if !haskey(kwargs, :metadata_module)
store_metadata(metadata_module, caller, determine_nb_parallel_indices(caller, get_body(kernelarg), extract_tuple(indices_expr)))
end
inbounds = haskey(kwargs, :inbounds) ? kwargs.inbounds : get_inbounds(caller)
padding = haskey(kwargs, :padding) ? kwargs.padding : get_padding(caller)
memopt = haskey(kwargs, :memopt) ? kwargs.memopt : get_memopt(caller)
if memopt
quote
$(parallel_indices_memopt(metadata_module, metadata_function, is_parallel_kernel, caller, package, posargs..., kernelarg; kwargs...)) #TODO: the package and numbertype will have to be passed here further once supported as kwargs (currently removed from call: package, numbertype, )
$metadata_function
$(parallel_indices_memopt(metadata_module, metadata_function, is_parallel_kernel, caller, package, posargs..., kernelarg; kwargs...)) #TODO: the package and numbertype will have to be passed here further once supported as kwargs (currently removed from call: package, numbertype, )
end
else
kwargs_expr = (:(inbounds=$inbounds), :(padding=$padding))
ParallelKernel.parallel_indices(caller, posargs..., kwargs_expr..., kernelarg; package=package)
kernel = ParallelKernel.parallel_indices(caller, posargs..., kwargs_expr..., kernelarg; package=package)
quote
$metadata_function
$kernel
end
end
end
end
Expand Down Expand Up @@ -288,6 +306,9 @@ function parallel_kernel(metadata_module::Module, metadata_function::Expr, calle
inbounds = haskey(kwargs, :inbounds) ? kwargs.inbounds : get_inbounds(caller)
padding = haskey(kwargs, :padding) ? kwargs.padding : get_padding(caller)
memopt = haskey(kwargs, :memopt) ? kwargs.memopt : get_memopt(caller)
if !haskey(kwargs, :metadata_module)
store_metadata(metadata_module, caller, ndims)
end
indices = get_indices_expr(ndims).args
indices_dir = get_indices_dir_expr(ndims).args
body = get_body(kernel)
Expand Down Expand Up @@ -323,14 +344,17 @@ function parallel_kernel(metadata_module::Module, metadata_function::Expr, calle
if memopt
expanded_kernel = macroexpand(caller, kernel)
quote
$(parallel_indices_memopt(metadata_module, metadata_function, is_parallel_kernel, caller, package, get_indices_expr(ndims), expanded_kernel; kwargs...)) #TODO: the package and numbertype will have to be passed here further once supported as kwargs (currently removed from call: package, numbertype, )
$metadata_function
$(parallel_indices_memopt(metadata_module, metadata_function, is_parallel_kernel, caller, package, get_indices_expr(ndims), expanded_kernel; kwargs...)) #TODO: the package and numbertype will have to be passed here further once supported as kwargs (currently removed from call: package, numbertype, )
end
else
if package == PKG_KERNELABSTRACTIONS
kernel = :(ParallelStencil.ParallelKernel.@ka_kernel $kernel)
end
return kernel # TODO: later could be here called parallel_indices instead of adding the threadids etc above.
return quote
$metadata_function
$kernel
end # TODO: later could be here called parallel_indices instead of adding the threadids etc above.
end
end

Expand Down Expand Up @@ -372,7 +396,7 @@ function parallel_call_memopt(caller::Module, kernelcall::Expr, backend_kwargs_e
metadata_module = metadata_call
loopdim = :($(metadata_module).loopdim)
is_parallel_kernel = :($(metadata_module).is_parallel_kernel)
ranges = :( ($is_parallel_kernel) ? ParallelStencil.get_ranges_memopt($nthreads_x_max, $nthreads_max_memopt, $loopdim, $(configcall.args[2:end]...)) : ParallelStencil.ParallelKernel.get_ranges($(configcall.args[2:end]...)))
ranges = add_nb_parallel_indices_check(:( ($is_parallel_kernel) ? ParallelStencil.get_ranges_memopt($nthreads_x_max, $nthreads_max_memopt, $loopdim, $(configcall.args[2:end]...)) : ParallelStencil.ParallelKernel.get_ranges($(configcall.args[2:end]...))), configcall)
parallel_call_memopt(caller, ranges, kernelcall, backend_kwargs_expr, async; memopt=memopt, configcall=configcall)
end

Expand Down Expand Up @@ -493,11 +517,21 @@ function get_indices_dir_expr(ndims::Integer)
end
end

"""
    determine_nb_parallel_indices(caller, body, indices)

Return the number of declared parallel `indices`, after verifying that the
(macro-expanded) kernel `body` does not use only a strict subset of them: if at
least one but not all declared indices appear in the body, raise an error listing
the unused ones. A body that uses none of the indices is accepted as-is.
"""
function determine_nb_parallel_indices(caller::Module, body::Expr, indices)
    expanded = macroexpand(caller, body)
    usage = [inexpr_walk(expanded, index) for index in indices]   # one flag per declared index
    if 0 < count(usage) < length(indices)
        unused_indices = [index for (index, used) in zip(indices, usage) if !used]
        @ArgumentError("@parallel_indices: all parallel indices must be used in the kernel body (unused indices: $(join(string.(unused_indices), ", "))).")
    end
    return length(indices)
end


## FUNCTIONS TO CREATE METADATA STORAGE

function create_metadata_storage(source::LineNumberNode, caller::Module, kernel::Expr)
kernelid = get_kernelid(get_name(kernel), source.file, source.line)
kernelid = get_kernelid(kernel, source.file, source.line)
create_module(caller, MOD_METADATA_PS)
topmodule = @eval(caller, $MOD_METADATA_PS)
create_module(topmodule, kernelid)
Expand Down Expand Up @@ -529,7 +563,22 @@ function create_metadata_call(configcall::Expr)
return metadata_call
end

"""
    store_metadata(metadata_module, caller, nb_parallel_indices)

Store `nb_parallel_indices` as a binding in `metadata_module`. The binding is
declared `const` unless non-constant metadata was requested for `caller` or the
binding is already defined in `metadata_module`.
"""
function store_metadata(metadata_module::Module, caller::Module, nb_parallel_indices::Integer)
    needs_nonconst = get_nonconst_metadata(caller) || isdefined(metadata_module, :nb_parallel_indices)
    storeexpr = needs_nonconst ? :(nb_parallel_indices = $nb_parallel_indices) :
                                 :(const nb_parallel_indices = $nb_parallel_indices)
    @eval(metadata_module, $storeexpr)
end

# Build a unique kernel identifier from the kernel's name and its definition location.
get_kernelid(kernelname, file, line) = Symbol(string(kernelname, '_', file, '_', line))
# Identifier for a kernel given as an expression: extends the name/location-based id
# with a hash of the kernel expression text so that distinct kernels at the same
# source location get distinct ids.
get_kernelid(kernel::Expr, file, line) = Symbol(string(get_kernelid(get_name(kernel), file, line), '_', hash(string(kernel))))
# Name of the hidden metadata accessor function generated for `kernelname`.
get_meta_function(kernelname) = Symbol(string(META_FUNCTION_PREFIX, GENSYM_SEPARATOR, kernelname))


Expand Down Expand Up @@ -606,3 +655,21 @@ function create_onthefly_macro(caller, m, expr, var, indices, indices_dir)
@eval(caller, $m_macro)
return
end


## FUNCTIONS TO CHECK THE AUTOMATIC DETERMINATION OF RANGES AND NB_PARALLEL_INDICES

"""
    add_nb_parallel_indices_check(ranges, configcall)

Wrap the expression `ranges` so that, at runtime and before it is evaluated, the
number of parallel indices recorded in the kernel's metadata module is compared
with the number of dimensions of the highest-dimensional input argument of
`configcall`; on mismatch, `ERRMSG_AUTOMATIC_RANGES_PARALLEL` is raised instead
of computing the ranges.
"""
function add_nb_parallel_indices_check(ranges::Union{Symbol,Expr}, configcall::Expr)
    callargs   = configcall.args[2:end]   # the kernel-call arguments (first arg is the callee)
    nb_indices = :(($(create_metadata_call(configcall))).nb_parallel_indices)
    nb_dims    = :(ParallelStencil.get_nb_input_dims($(callargs...)))
    errorcall  = :(ParallelStencil.@ArgumentError(ParallelStencil.ERRMSG_AUTOMATIC_RANGES_PARALLEL))
    return :(($nb_dims != $nb_indices && $errorcall; $ranges))
end

# Number of dimensions the automatic `ranges` computation attributes to each
# kernel-call argument; the overall result is the maximum over all arguments
# (at least 1, so that calls with only scalar arguments remain valid).
get_nb_input_dims(args...) = maximum((get_nb_input_dims(arg) for arg in args); init=1)
get_nb_input_dims(t::T) where T<:Union{Tuple,NamedTuple} = get_nb_input_dims(t...)   # recurse into (named) tuples
get_nb_input_dims(A::AbstractArray) = ndims(A)
get_nb_input_dims(A::SubArray) = ndims(A.parent)   # NOTE(review): uses the parent's dims, not the view's — confirm intended
get_nb_input_dims(a::Number) = 1
function get_nb_input_dims(x)
    # Any other bitstype argument is treated as dimensionless; for non-bitstype
    # arguments an automatic determination is impossible.
    isbitstype(typeof(x)) && return 1
    @ArgumentError("automatic detection of ranges not possible in @parallel <kernelcall>: some kernel arguments are neither arrays nor scalars nor any other bitstypes nor (named) tuple containing any of the former. Specify ranges or nthreads and nblocks manually.")
end
8 changes: 8 additions & 0 deletions src/shared.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,11 @@ check_nonconst_metadata(nonconst_metadata) = ( if !isa(nonconst_metadata, Bool)
## FUNCTIONS/MACROS FOR DIVERSE SYNTAX SUGAR

hasmeta_PS(caller::Module) = isdefined(caller, MOD_METADATA_PS)


## FUNCTIONS AND MACROS FOR UNIT TESTS

"""
    @metadata(kernelcall)

Expand to the (escaped) expression that accesses the metadata module associated
with `kernelcall`; intended for unit tests. Raises an error if the argument is
not a call expression.
"""
macro metadata(kernelcall)
    is_call(kernelcall) || @ArgumentError("@metadata: the argument must be a kernel call (obtained: $kernelcall).")
    return esc(create_metadata_call(kernelcall))
end
54 changes: 52 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,65 @@
# NOTE: This file contains many parts that are copied from the file runtests.jl from the Package MPI.jl.
push!(LOAD_PATH, "../src")

# Suppression rules used while importing backend packages (see
# `import_with_filtered_stderr`), before the main `STDERR_SUPPRESSION_RULES`
# constant is defined further below in this file.
const PREIMPORT_STDERR_SUPPRESSION_RULES = (
(name="Metal OS support warnings", start=r"^┌ Error: Metal\.jl is only supported on macOS$", stop=r"^└ @ Metal .*$"),
)

# Matches ANSI color/style escape sequences, so suppression regexes can be applied
# to the plain text of each captured line.
const ANSI_ESCAPE_REGEX = r"\e\[[0-9;]*m"

"""
    filter_stderr_content(text; rules=STDERR_SUPPRESSION_RULES)

Return `text` with every line span matched by one of the suppression `rules`
removed. Each rule is a named tuple `(name, start, stop)`: a line matching
`start` begins a suppressed span; when `stop` is a regex, all subsequent lines up
to and including the first one matching `stop` are dropped as well, while
`stop === nothing` drops only the `start` line itself. Matching is done on lines
stripped of ANSI escapes, but surviving lines keep their original content.
"""
function filter_stderr_content(text::AbstractString; rules=STDERR_SUPPRESSION_RULES)
    isempty(text) && return text
    kept = String[]
    pending_stop = nothing                        # stop-regex of the currently suppressed span
    for raw in split(text, '\n'; keepempty=true)
        plain = replace(raw, ANSI_ESCAPE_REGEX => "")
        if pending_stop !== nothing               # inside a suppressed span: drop this line
            if occursin(pending_stop, plain)
                pending_stop = nothing            # the span ends with (and includes) this line
            end
            continue
        end
        rule_hit = nothing
        for rule in rules
            if occursin(rule.start, plain)
                rule_hit = rule
                break
            end
        end
        if rule_hit === nothing
            push!(kept, raw)
        else
            pending_stop = rule_hit.stop          # `nothing` here means single-line suppression
        end
    end
    return join(kept, '\n')
end

"""
    import_with_filtered_stderr(modulename; rules=PREIMPORT_STDERR_SUPPRESSION_RULES)

Import the module named `modulename` while capturing everything it writes to
stderr in a temporary file; afterwards re-emit to the real stderr only the
captured lines not suppressed by `rules` (see `filter_stderr_content`).
"""
function import_with_filtered_stderr(modulename::Symbol; rules=PREIMPORT_STDERR_SUPPRESSION_RULES)
    mktemp() do path, io
        redirect_stderr(io) do
            @eval import $(modulename)
        end
        flush(io)     # ensure all captured output has reached the file before reading it
        close(io)
        leftover = filter_stderr_content(read(path, String); rules=rules)
        if !isempty(leftover)
            print(Base.stderr, leftover)
        end
    end
end

import ParallelStencil # Precompile it.
import ParallelStencil: SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_METAL, PKG_KERNELABSTRACTIONS
@static if (PKG_CUDA in SUPPORTED_PACKAGES) import CUDA end
@static if (PKG_AMDGPU in SUPPORTED_PACKAGES) import AMDGPU end
@static if (PKG_METAL in SUPPORTED_PACKAGES) import Metal end
@static if (PKG_METAL in SUPPORTED_PACKAGES) import_with_filtered_stderr(:Metal) end
@static if (PKG_KERNELABSTRACTIONS in SUPPORTED_PACKAGES) import KernelAbstractions end # KernelAbstractions does not require extra harness env vars beyond the existing CUDA/AMDGPU settings.

excludedfiles = [ "test_excluded.jl", "test_incremental_compilation.jl", "test_revise.jl"]; # TODO: test_incremental_compilation has to be deactivated until Polyester support released

# Suppression rules applied (via `filter_stderr_content`) to the stderr captured
# from each test-file subprocess. A rule with `stop = nothing` drops only the line
# matching `start`; a rule with a `stop` regex drops all lines from the `start`
# match up to and including the first `stop` match.
const STDERR_SUPPRESSION_RULES = (
(name="metadata method overwrite warnings", start=r"^WARNING: Method definition .*###META.* overwritten.*$", stop=nothing),
(name="[T]Data module replacement warnings", start=r"^WARNING: replacing module [T]?Data\.$", stop=nothing),
(name="Metal OS support warnings", start=r"^┌ Error: Metal\.jl is only supported on macOS$", stop=r"^└ @ Metal .*$"),
)

function runtests(testfiles=String[]; stop_on_fail=false)
exename = joinpath(Sys.BINDIR, Base.julia_exename())
testdir = pwd()
Expand Down Expand Up @@ -61,7 +111,7 @@ function runtests(testfiles=String[]; stop_on_fail=false)
stdout_content = read(stdout_path, String)
stderr_content = read(stderr_path, String)
print(stdout_content)
print(Base.stderr, stderr_content)
print(Base.stderr, filter_stderr_content(stderr_content))
catch ex
println("Test Abort: a system-level exception occurred while running the test file $f :")
println(ex)
Expand Down
Loading
Loading