From 0c73a89a1e13312af78625ba80efcd34ebe96611 Mon Sep 17 00:00:00 2001 From: SimonDanisch Date: Thu, 8 Jan 2026 13:33:03 -0100 Subject: [PATCH 1/4] Auto-detect SPIR-V extensions from device capabilities When using KernelAbstractions with features like Atomix.@atomic on Float32, the required SPIR-V extensions (e.g., SPV_EXT_shader_atomic_float_add) were not being enabled because KernelAbstractions calls @opencl internally without the `extensions` parameter. This change automatically queries the device's OpenCL extensions and maps them to the corresponding SPIR-V extensions that should be enabled during compilation. Currently supported mappings: - cl_ext_float_atomics -> SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max The mapping dictionary can be extended as needed for other extension pairs. Fixes the issue where Float32 atomics fail through KernelAbstractions even when the device supports cl_ext_float_atomics. Co-Authored-By: Claude Opus 4.5 --- src/compiler/compilation.jl | 40 +++++++++++++++++++++++++++++++++++-- 1 file changed, 38 insertions(+), 2 deletions(-) diff --git a/src/compiler/compilation.jl b/src/compiler/compilation.jl index 03637b97..a8086cfb 100644 --- a/src/compiler/compilation.jl +++ b/src/compiler/compilation.jl @@ -104,6 +104,35 @@ end ## compiler implementation (cache, configure, compile, and link) +# Mapping from OpenCL device extensions to SPIR-V extensions that should be +# automatically enabled during compilation. This allows features like float atomics +# to work transparently through KernelAbstractions without manual extension specification. +const OPENCL_TO_SPIRV_EXTENSIONS = Dict{String, Vector{String}}( + "cl_ext_float_atomics" => [ + "SPV_EXT_shader_atomic_float_add", + "SPV_EXT_shader_atomic_float_min_max", + ], +) + +""" + spirv_extensions_for_device(dev::cl.Device) -> Vector{String} + +Query the device's OpenCL extensions and return the corresponding SPIR-V extensions +that should be enabled for compilation. +""" +function spirv_extensions_for_device(dev::cl.Device) + spirv_exts = String[] + device_exts = dev.extensions + + for (cl_ext, spirv_ext_list) in OPENCL_TO_SPIRV_EXTENSIONS + if cl_ext in device_exts + append!(spirv_exts, spirv_ext_list) + end + end + + return spirv_exts +end + # cache of compilation caches, per context const _compiler_caches = Dict{cl.Context, Dict{Any, Any}}() function compiler_cache(ctx::cl.Context) @@ -127,12 +156,19 @@ function compiler_config(dev::cl.Device; kwargs...) end return config end -@noinline function _compiler_config(dev; kernel=true, name=nothing, always_inline=false, kwargs...) +@noinline function _compiler_config(dev; kernel=true, name=nothing, always_inline=false, + extensions::Vector{String}=String[], kwargs...) supports_fp16 = "cl_khr_fp16" in dev.extensions supports_fp64 = "cl_khr_fp64" in dev.extensions + # Auto-detect SPIR-V extensions from device capabilities and merge with + # any explicitly requested extensions + auto_extensions = spirv_extensions_for_device(dev) + all_extensions = unique(vcat(extensions, auto_extensions)) + # create GPUCompiler objects - target = SPIRVCompilerTarget(; supports_fp16, supports_fp64, validate=true, kwargs...) + target = SPIRVCompilerTarget(; supports_fp16, supports_fp64, validate=true, + extensions=all_extensions, kwargs...) params = OpenCLCompilerParams() CompilerConfig(target, params; kernel, name, always_inline) end From 24cbd430106603609082229c9ff82ecea6fe5125 Mon Sep 17 00:00:00 2001 From: SimonDanisch Date: Wed, 28 Jan 2026 15:05:05 +0100 Subject: [PATCH 2/4] add SPV_KHR_bit_instructions --- lib/intrinsics/src/integer.jl | 1 + src/compiler/compilation.jl | 3 +++ 2 files changed, 4 insertions(+) diff --git a/lib/intrinsics/src/integer.jl b/lib/intrinsics/src/integer.jl index 7e36f02c..0d756d53 100644 --- a/lib/intrinsics/src/integer.jl +++ b/lib/intrinsics/src/integer.jl @@ -36,6 +36,7 @@ for gentype in generic_integer_types @device_function popcount(x::$gentype) = @builtin_ccall("popcount", $gentype, ($gentype,), x) + @device_function mad24(x::$gentype, y::$gentype, z::$gentype) = @builtin_ccall("mad24", $gentype, ($gentype, $gentype, $gentype), x, y, z) @device_function mul24(x::$gentype, y::$gentype) = @builtin_ccall("mul24", $gentype, ($gentype, $gentype), x, y) diff --git a/src/compiler/compilation.jl b/src/compiler/compilation.jl index a8086cfb..52b7607b 100644 --- a/src/compiler/compilation.jl +++ b/src/compiler/compilation.jl @@ -112,6 +112,9 @@ const OPENCL_TO_SPIRV_EXTENSIONS = Dict{String, Vector{String}}( "SPV_EXT_shader_atomic_float_add", "SPV_EXT_shader_atomic_float_min_max", ], + "cl_khr_extended_bit_ops" => [ + "SPV_KHR_bit_instructions", + ], ) """ From 38b583d7ed4ccaaf7aecee3e2c8fbfc25d1fe109 Mon Sep 17 00:00:00 2001 From: SimonDanisch Date: Thu, 29 Jan 2026 15:47:58 +0100 Subject: [PATCH 3/4] add bitreverse --- lib/intrinsics/src/integer.jl | 1 + src/compiler/compilation.jl | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/intrinsics/src/integer.jl b/lib/intrinsics/src/integer.jl index 0d756d53..1a76c38b 100644 --- a/lib/intrinsics/src/integer.jl +++ b/lib/intrinsics/src/integer.jl @@ -36,6 +36,7 @@ for gentype in generic_integer_types @device_function popcount(x::$gentype) = @builtin_ccall("popcount", $gentype, ($gentype,), x) +@device_override Base.bitreverse(x::$gentype) = @builtin_ccall("bit_reverse", $gentype, ($gentype,), x) @device_function mad24(x::$gentype, y::$gentype, z::$gentype) = @builtin_ccall("mad24", $gentype, ($gentype, $gentype, $gentype), x, y, z) @device_function mul24(x::$gentype, y::$gentype) = @builtin_ccall("mul24", $gentype, ($gentype, $gentype), x, y) diff --git a/src/compiler/compilation.jl b/src/compiler/compilation.jl index 52b7607b..d084f3ed 100644 --- a/src/compiler/compilation.jl +++ b/src/compiler/compilation.jl @@ -124,7 +124,7 @@ Query the device's OpenCL extensions and return the corresponding SPIR-V extensi that should be enabled for compilation. """ function spirv_extensions_for_device(dev::cl.Device) - spirv_exts = String[] + spirv_exts = String["SPV_KHR_bit_instructions"] device_exts = dev.extensions for (cl_ext, spirv_ext_list) in OPENCL_TO_SPIRV_EXTENSIONS @@ -167,7 +167,7 @@ end # Auto-detect SPIR-V extensions from device capabilities and merge with # any explicitly requested extensions auto_extensions = spirv_extensions_for_device(dev) - all_extensions = unique(vcat(extensions, auto_extensions)) + all_extensions = unique!(vcat(extensions, auto_extensions)) # create GPUCompiler objects target = SPIRVCompilerTarget(; supports_fp16, supports_fp64, validate=true, From 33228a23580c87777a26fcf53282c0ef95122d76 Mon Sep 17 00:00:00 2001 From: SimonDanisch Date: Fri, 20 Feb 2026 13:13:42 +0100 Subject: [PATCH 4/4] add tests --- test/Project.toml | 2 ++ test/setup.jl | 2 +- test/spirv_extensions.jl | 65 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 test/spirv_extensions.jl diff --git a/test/Project.toml b/test/Project.toml index 57ae7ff9..3dbcc16a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" @@ -8,6 +9,7 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +OpenCL = "08131aa3-fb12-5dee-8b74-c09406e224a2" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" diff --git a/test/setup.jl b/test/setup.jl index 90337d36..b6bd9a15 100644 --- a/test/setup.jl +++ b/test/setup.jl @@ -90,7 +90,7 @@ function runtests(f, name, platform_filter) end # some tests require native execution capabilities - requires_il = name in ["atomics", "execution", "intrinsics", "kernelabstractions"] || + requires_il = name in ["atomics", "execution", "intrinsics", "kernelabstractions", "spirv_extensions"] || startswith(name, "gpuarrays/") || startswith(name, "device/") ex = quote diff --git a/test/spirv_extensions.jl b/test/spirv_extensions.jl new file mode 100644 index 00000000..26451536 --- /dev/null +++ b/test/spirv_extensions.jl @@ -0,0 +1,65 @@ +using KernelAbstractions +using Atomix: Atomix + +@testset "spirv_extensions" begin + +@testset "bitreverse KernelAbstractions kernel" begin + @kernel function bitreverse_ka!(out, inp) + i = @index(Global) + @inbounds out[i] = bitreverse(inp[i]) + end + + @testset "$T" for T in [Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64] + N = 64 + inp = CLArray(rand(T, N)) + out = similar(inp) + bitreverse_ka!(OpenCLBackend())(out, inp; ndrange=N) + synchronize(OpenCLBackend()) + + @test Array(out) == bitreverse.(Array(inp)) + end +end + +@testset "atomic float accumulation without manual extensions" begin + # The auto-spirv-extensions feature detects cl_ext_float_atomics and enables + # SPV_EXT_shader_atomic_float_add automatically. Previously this required + # passing extensions=["SPV_EXT_shader_atomic_float_add"] to @opencl manually. + # + # We test with a concurrent accumulation pattern where multiple work-items + # write to the same output locations. Without atomics this would race; + # with Atomix.@atomic (which emits OpAtomicFAddEXT) the result must be exact. + if "cl_ext_float_atomics" in cl.device().extensions + @kernel function atomic_accum_kernel!(out, arr) + i, j = @index(Global, NTuple) + for k in 1:size(out, 1) + Atomix.@atomic out[k, i] += arr[i, j] + end + end + + @testset "$T" for T in [Float32, Float64] + if T == Float64 && !("cl_khr_fp64" in cl.device().extensions) + continue + end + + M, N = 32, 64 + img = zeros(T, M, N) + img[5:15, 5:15] .= one(T) + img[20:30, 20:30] .= T(2) + + cl_img = CLArray(img) + out = KernelAbstractions.zeros(OpenCLBackend(), T, M, N) + atomic_accum_kernel!(OpenCLBackend())(out, cl_img; ndrange=(M, N)) + synchronize(OpenCLBackend()) + + # Each out[k, i] = sum(img[i, :]) — accumulate row i across all columns + out_host = Array(out) + expected = zeros(T, M, N) + for i in 1:M + expected[:, i] .= sum(img[i, :]) + end + @test out_host ≈ expected + end + end +end + +end