From 7fbbe55932a6c6777ae6267ea1a1581a61021890 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Mon, 15 Dec 2025 23:50:38 +0100 Subject: [PATCH 1/3] support emitting a generated function --- src/KernelAbstractions.jl | 11 ++++++++--- src/macros.jl | 8 +++++++- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/KernelAbstractions.jl b/src/KernelAbstractions.jl index 8ef7c071d..e643a6238 100644 --- a/src/KernelAbstractions.jl +++ b/src/KernelAbstractions.jl @@ -50,7 +50,7 @@ synchronize(backend) ``` """ macro kernel(expr) - return __kernel(expr, #=generate_cpu=# true, #=force_inbounds=# false, #=unsafe_indices=# false) + return __kernel(expr, #=generate_cpu=# true, #=force_inbounds=# false, #=unsafe_indices=# false, #=generated=# false) end """ @@ -69,11 +69,12 @@ This allows for two different configurations: """ macro kernel(ex...) if length(ex) == 1 - return __kernel(ex[1], true, false, false) + return __kernel(ex[1], true, false, false, false) else generate_cpu = true unsafe_indices = false force_inbounds = false + generated = false for i in 1:(length(ex) - 1) if ex[i] isa Expr && ex[i].head == :(=) && ex[i].args[1] == :cpu && ex[i].args[2] isa Bool @@ -84,17 +85,21 @@ macro kernel(ex...) elseif ex[i] isa Expr && ex[i].head == :(=) && ex[i].args[1] == :unsafe_indices && ex[i].args[2] isa Bool unsafe_indices = ex[i].args[2] + elseif ex[i] isa Expr && ex[i].head == :(=) && + ex[i].args[1] == :generated && ex[i].args[2] isa Bool + generated = ex[i].args[2] else error( "Configuration should be of form:\n" * "* `cpu=false`\n" * "* `inbounds=true`\n" * "* `unsafe_indices=true`\n" * + "* `generated=true`\n" * "got `", ex[i], "`", ) end end - return __kernel(ex[end], generate_cpu, force_inbounds, unsafe_indices) + return __kernel(ex[end], generate_cpu, force_inbounds, unsafe_indices, generated) end end diff --git a/src/macros.jl b/src/macros.jl index 696008727..4e8aaeb28 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -10,7 +10,7 @@ function find_return(stmt) end # XXX: Proper errors -function __kernel(expr, generate_cpu = true, force_inbounds = false, unsafe_indices = false) +function __kernel(expr, generate_cpu = true, force_inbounds = false, unsafe_indices = false, generated = false) def = splitdef(expr) name = def[:name] args = def[:args] @@ -41,12 +41,18 @@ function __kernel(expr, generate_cpu = true, force_inbounds = false, unsafe_indi def_cpu = deepcopy(def) def_cpu[:name] = cpu_name transform_cpu!(def_cpu, constargs, force_inbounds) + if generated + def_cpu[:body] = Expr(:if, Expr(:generated), Expr(:copyast, QuoteNode(def_cpu[:body])), Expr(:meta, :generated_only)) + end cpu_function = combinedef(def_cpu) end def_gpu = deepcopy(def) def_gpu[:name] = gpu_name = Symbol(:gpu_, name) transform_gpu!(def_gpu, constargs, force_inbounds, unsafe_indices) + if generated + def_gpu[:body] = Expr(:if, Expr(:generated), Expr(:copyast, QuoteNode(def_gpu[:body])), Expr(:meta, :generated_only)) + end gpu_function = combinedef(def_gpu) # create constructor functions From 9195d94e9bb10ae5a992363d3fb9de162e93f13a Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Tue, 16 Dec 2025 10:55:28 +0100 Subject: [PATCH 2/3] with proper interpolation --- src/KernelAbstractions.jl | 6 +++--- src/macros.jl | 10 +++++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/KernelAbstractions.jl b/src/KernelAbstractions.jl index e643a6238..9469c21a3 100644 --- a/src/KernelAbstractions.jl +++ b/src/KernelAbstractions.jl @@ -50,7 +50,7 @@ synchronize(backend) ``` """ macro kernel(expr) - return __kernel(expr, #=generate_cpu=# true, #=force_inbounds=# false, #=unsafe_indices=# false, #=generated=# false) + return __kernel(__module__, expr, #=generate_cpu=# true, #=force_inbounds=# false, #=unsafe_indices=# false, #=generated=# false) end """ @@ -69,7 +69,7 @@ This allows for two different configurations: """ macro kernel(ex...) if length(ex) == 1 - return __kernel(ex[1], true, false, false, false) + return __kernel(__module__, ex[1], true, false, false, false) else generate_cpu = true unsafe_indices = false @@ -99,7 +99,7 @@ macro kernel(ex...) ) end end - return __kernel(ex[end], generate_cpu, force_inbounds, unsafe_indices, generated) + return __kernel(__module__, ex[end], generate_cpu, force_inbounds, unsafe_indices, generated) end end diff --git a/src/macros.jl b/src/macros.jl index 4e8aaeb28..ddf67d2dd 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -10,7 +10,7 @@ function find_return(stmt) end # XXX: Proper errors -function __kernel(expr, generate_cpu = true, force_inbounds = false, unsafe_indices = false, generated = false) +function __kernel(__module__, expr, generate_cpu = true, force_inbounds = false, unsafe_indices = false, generated = false) def = splitdef(expr) name = def[:name] args = def[:args] @@ -42,7 +42,9 @@ function __kernel(expr, generate_cpu = true, force_inbounds = false, unsafe_indi def_cpu[:name] = cpu_name transform_cpu!(def_cpu, constargs, force_inbounds) if generated - def_cpu[:body] = Expr(:if, Expr(:generated), Expr(:copyast, QuoteNode(def_cpu[:body])), Expr(:meta, :generated_only)) + # Use macroexpand to perform the annoying work of interpolating `$` exprs + body = macroexpand(__module__, Expr(:quote, def_cpu[:body]), recursive = false) + def_cpu[:body] = Expr(:if, Expr(:generated), body, Expr(:meta, :generated_only)) end cpu_function = combinedef(def_cpu) end @@ -51,7 +53,9 @@ function __kernel(expr, generate_cpu = true, force_inbounds = false, unsafe_indi def_gpu[:name] = gpu_name = Symbol(:gpu_, name) transform_gpu!(def_gpu, constargs, force_inbounds, unsafe_indices) if generated - def_gpu[:body] = Expr(:if, Expr(:generated), Expr(:copyast, QuoteNode(def_gpu[:body])), Expr(:meta, :generated_only)) + # Use macroexpand to perform the annoying work of interpolating `$` exprs + body = macroexpand(__module__, Expr(:quote, def_gpu[:body]), recursive = false) + def_gpu[:body] = Expr(:if, Expr(:generated), body, Expr(:meta, :generated_only)) end gpu_function = combinedef(def_gpu) From d4ffa3b2ec53fab879706cdab0dacafa7d8535b3 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 30 Jan 2026 21:52:31 +0100 Subject: [PATCH 3/3] add missing test --- test/reflection.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/reflection.jl b/test/reflection.jl index 6ce46b2b1..9952e098e 100644 --- a/test/reflection.jl +++ b/test/reflection.jl @@ -15,6 +15,11 @@ end A[I] = i + C[I] end +@kernel generated = true function f(::Val{N}) where {N} + KernelAbstractions.Extras.@unroll $N for i in 1:10 + end +end + function test_typed_kernel_dynamic(backend, backend_str, ArrayT) A = ArrayT(ones(Float32, 1024, 1024)) kernel = mul2(backend()) @@ -102,5 +107,6 @@ function reflection_testsuite(backend, backend_str, ArrayT) test_typed_kernel_static(backend, backend_str, ArrayT) test_typed_kernel_no_optimize(backend, backend_str, ArrayT) test_expr_kernel(backend, backend_str, ArrayT) + test_generated_kernel(backend, backend_str, ArrayT) return end