Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lib/intrinsics/src/integer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ for gentype in generic_integer_types

# Forward `popcount` for this integer type to the builtin of the same name.
@device_function popcount(x::$gentype) = @builtin_ccall("popcount", $gentype, ($gentype,), x)

# Device-side override of `Base.bitreverse`, forwarded to the `bit_reverse` builtin.
# NOTE(review): presumably depends on the SPV_KHR_bit_instructions SPIR-V extension
# being enabled at compile time — confirm.
@device_override Base.bitreverse(x::$gentype) = @builtin_ccall("bit_reverse", $gentype, ($gentype,), x)

# `mad24` / `mul24` builtins; all arguments and the result share the same integer type.
# NOTE(review): the names suggest 24-bit multiply(-add) semantics; not visible here — confirm.
@device_function mad24(x::$gentype, y::$gentype, z::$gentype) = @builtin_ccall("mad24", $gentype, ($gentype, $gentype, $gentype), x, y, z)
@device_function mul24(x::$gentype, y::$gentype) = @builtin_ccall("mul24", $gentype, ($gentype, $gentype), x, y)

Expand Down
43 changes: 41 additions & 2 deletions src/compiler/compilation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,38 @@ end

## compiler implementation (cache, configure, compile, and link)

# Lookup table: OpenCL device extension name => SPIR-V extensions that get switched on
# automatically at compile time when the device reports that OpenCL extension.
# With this in place, features such as float atomics work transparently through
# KernelAbstractions, without callers having to list extensions by hand.
const OPENCL_TO_SPIRV_EXTENSIONS = Dict{String, Vector{String}}(
    "cl_ext_float_atomics" => String[
        "SPV_EXT_shader_atomic_float_add",
        "SPV_EXT_shader_atomic_float_min_max",
    ],
    "cl_khr_extended_bit_ops" => String["SPV_KHR_bit_instructions"],
)

"""
spirv_extensions_for_device(dev::cl.Device) -> Vector{String}

Query the device's OpenCL extensions and return the corresponding SPIR-V extensions
that should be enabled for compilation.
"""
function spirv_extensions_for_device(dev::cl.Device)
    # "SPV_KHR_bit_instructions" is requested unconditionally, independent of what the
    # device reports.
    # NOTE(review): presumably because `bitreverse` is always lowered to the
    # `bit_reverse` builtin — confirm this is intended for devices that do not
    # advertise `cl_khr_extended_bit_ops`.
    spirv_exts = String["SPV_KHR_bit_instructions"]
    device_exts = dev.extensions

    # Add the SPIR-V extensions mapped from every OpenCL extension the device reports.
    for (cl_ext, spirv_ext_list) in OPENCL_TO_SPIRV_EXTENSIONS
        if cl_ext in device_exts
            append!(spirv_exts, spirv_ext_list)
        end
    end

    # The table may map to an entry that is already present (e.g. the unconditional
    # baseline above), so drop repeats; `unique!` keeps the first occurrence of each.
    return unique!(spirv_exts)
end

# cache of compilation caches, per context
const _compiler_caches = Dict{cl.Context, Dict{Any, Any}}()
function compiler_cache(ctx::cl.Context)
Expand All @@ -127,12 +159,19 @@ function compiler_config(dev::cl.Device; kwargs...)
end
return config
end
# Build the GPUCompiler `CompilerConfig` for `dev`: a SPIR-V compiler target whose
# fp16/fp64 support flags are read from the device's extension list, combined with
# `OpenCLCompilerParams()`. Remaining keyword arguments are forwarded to the target.
# NOTE(review): the single-line signature directly below and the two-line signature after
# it look like the removed and added sides of the same diff hunk — only the form that
# accepts `extensions` is expected to remain; confirm.
@noinline function _compiler_config(dev; kernel=true, name=nothing, always_inline=false, kwargs...)
@noinline function _compiler_config(dev; kernel=true, name=nothing, always_inline=false,
extensions::Vector{String}=String[], kwargs...)
supports_fp16 = "cl_khr_fp16" in dev.extensions
supports_fp64 = "cl_khr_fp64" in dev.extensions

# Auto-detect SPIR-V extensions from device capabilities and merge with
# any explicitly requested extensions. Explicit entries are concatenated first and
# `unique!` removes any entry that occurs in both lists.
auto_extensions = spirv_extensions_for_device(dev)
all_extensions = unique!(vcat(extensions, auto_extensions))

# create GPUCompiler objects
# NOTE(review): the first `target = ...` line (without `extensions=`) appears to be the
# pre-change side of the diff; the second one passes the merged extension list — confirm.
target = SPIRVCompilerTarget(; supports_fp16, supports_fp64, validate=true, kwargs...)
target = SPIRVCompilerTarget(; supports_fp16, supports_fp64, validate=true,
extensions=all_extensions, kwargs...)
params = OpenCLCompilerParams()
CompilerConfig(target, params; kernel, name, always_inline)
end
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
Expand All @@ -8,6 +9,7 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OpenCL = "08131aa3-fb12-5dee-8b74-c09406e224a2"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
Expand Down
2 changes: 1 addition & 1 deletion test/setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ function runtests(f, name, platform_filter)
end

# some tests require native execution capabilities
requires_il = name in ["atomics", "execution", "intrinsics", "kernelabstractions"] ||
requires_il = name in ["atomics", "execution", "intrinsics", "kernelabstractions", "spirv_extensions"] ||
startswith(name, "gpuarrays/") || startswith(name, "device/")

ex = quote
Expand Down
65 changes: 65 additions & 0 deletions test/spirv_extensions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
using KernelAbstractions
using Atomix: Atomix

# End-to-end checks that kernels depending on SPIR-V extensions compile and run without
# the caller passing an explicit `extensions=...` list.
@testset "spirv_extensions" begin

    # `bitreverse` inside a KernelAbstractions kernel must match the host-side result
    # for every tested integer width.
    @testset "bitreverse KernelAbstractions kernel" begin
        @kernel function bitreverse_ka!(dst, src)
            idx = @index(Global)
            @inbounds dst[idx] = bitreverse(src[idx])
        end

        @testset "$T" for T in [Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64]
            len = 64
            src = CLArray(rand(T, len))
            dst = similar(src)

            backend = OpenCLBackend()
            bitreverse_ka!(backend)(dst, src; ndrange=len)
            synchronize(backend)

            @test Array(dst) == bitreverse.(Array(src))
        end
    end

    @testset "atomic float accumulation without manual extensions" begin
        # When the device advertises cl_ext_float_atomics, SPV_EXT_shader_atomic_float_add
        # is enabled automatically; previously callers had to pass
        # extensions=["SPV_EXT_shader_atomic_float_add"] to @opencl themselves.
        #
        # The kernel below has many work-items accumulating into the same output
        # locations. A plain `+=` would race; with Atomix.@atomic (which emits
        # OpAtomicFAddEXT) the accumulated result must be exact.
        device_exts = cl.device().extensions
        if "cl_ext_float_atomics" in device_exts
            @kernel function atomic_accum_kernel!(acc, vals)
                row, col = @index(Global, NTuple)
                for k in 1:size(acc, 1)
                    Atomix.@atomic acc[k, row] += vals[row, col]
                end
            end

            @testset "$T" for T in [Float32, Float64]
                # Float64 is only exercised on devices with double-precision support.
                T == Float64 && "cl_khr_fp64" ∉ device_exts && continue

                M, N = 32, 64
                img = zeros(T, M, N)
                img[5:15, 5:15] .= one(T)
                img[20:30, 20:30] .= T(2)

                backend = OpenCLBackend()
                cl_img = CLArray(img)
                out = KernelAbstractions.zeros(backend, T, M, N)
                atomic_accum_kernel!(backend)(out, cl_img; ndrange=(M, N))
                synchronize(backend)

                # Reference: every out[k, r] holds sum(img[r, :]), i.e. row r of the
                # input accumulated across all of its columns.
                expected = zeros(T, M, N)
                for r in 1:M
                    expected[:, r] .= sum(img[r, :])
                end

                out_host = Array(out)
                @test out_host ≈ expected
            end
        end
    end

end