diff --git a/src/KernelAbstractions.jl b/src/KernelAbstractions.jl index 8ef7c071d..9469c21a3 100644 --- a/src/KernelAbstractions.jl +++ b/src/KernelAbstractions.jl @@ -50,7 +50,7 @@ synchronize(backend) ``` """ macro kernel(expr) - return __kernel(expr, #=generate_cpu=# true, #=force_inbounds=# false, #=unsafe_indices=# false) + return __kernel(__module__, expr, #=generate_cpu=# true, #=force_inbounds=# false, #=unsafe_indices=# false, #=generated=# false) end """ @@ -69,11 +69,12 @@ This allows for two different configurations: """ macro kernel(ex...) if length(ex) == 1 - return __kernel(ex[1], true, false, false) + return __kernel(__module__, ex[1], true, false, false, false) else generate_cpu = true unsafe_indices = false force_inbounds = false + generated = false for i in 1:(length(ex) - 1) if ex[i] isa Expr && ex[i].head == :(=) && ex[i].args[1] == :cpu && ex[i].args[2] isa Bool @@ -84,17 +85,21 @@ macro kernel(ex...) elseif ex[i] isa Expr && ex[i].head == :(=) && ex[i].args[1] == :unsafe_indices && ex[i].args[2] isa Bool unsafe_indices = ex[i].args[2] + elseif ex[i] isa Expr && ex[i].head == :(=) && + ex[i].args[1] == :generated && ex[i].args[2] isa Bool + generated = ex[i].args[2] else error( "Configuration should be of form:\n" * "* `cpu=false`\n" * "* `inbounds=true`\n" * "* `unsafe_indices=true`\n" * + "* `generated=true`\n" * "got `", ex[i], "`", ) end end - return __kernel(ex[end], generate_cpu, force_inbounds, unsafe_indices) + return __kernel(__module__, ex[end], generate_cpu, force_inbounds, unsafe_indices, generated) end end diff --git a/src/macros.jl b/src/macros.jl index 696008727..ddf67d2dd 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -10,7 +10,7 @@ function find_return(stmt) end # XXX: Proper errors -function __kernel(expr, generate_cpu = true, force_inbounds = false, unsafe_indices = false) +function __kernel(__module__, expr, generate_cpu = true, force_inbounds = false, unsafe_indices = false, generated = false) def = splitdef(expr) name = def[:name] args = def[:args] @@ -41,12 +41,22 @@ function __kernel(expr, generate_cpu = true, force_inbounds = false, unsafe_indi def_cpu = deepcopy(def) def_cpu[:name] = cpu_name transform_cpu!(def_cpu, constargs, force_inbounds) + if generated + # Use macroexpand to perform the annoying work of interpolating `$` exprs + body = macroexpand(__module__, Expr(:quote, def_cpu[:body]), recursive = false) + def_cpu[:body] = Expr(:if, Expr(:generated), body, Expr(:meta, :generated_only)) + end cpu_function = combinedef(def_cpu) end def_gpu = deepcopy(def) def_gpu[:name] = gpu_name = Symbol(:gpu_, name) transform_gpu!(def_gpu, constargs, force_inbounds, unsafe_indices) + if generated + # Use macroexpand to perform the annoying work of interpolating `$` exprs + body = macroexpand(__module__, Expr(:quote, def_gpu[:body]), recursive = false) + def_gpu[:body] = Expr(:if, Expr(:generated), body, Expr(:meta, :generated_only)) + end gpu_function = combinedef(def_gpu) # create constructor functions diff --git a/test/reflection.jl b/test/reflection.jl index 6ce46b2b1..9952e098e 100644 --- a/test/reflection.jl +++ b/test/reflection.jl @@ -15,6 +15,11 @@ end A[I] = i + C[I] end +@kernel generated = true function f(::Val{N}) where {N} + KernelAbstractions.Extras.@unroll $N for i in 1:10 + end +end + function test_typed_kernel_dynamic(backend, backend_str, ArrayT) A = ArrayT(ones(Float32, 1024, 1024)) kernel = mul2(backend()) @@ -102,5 +107,6 @@ function reflection_testsuite(backend, backend_str, ArrayT) test_typed_kernel_static(backend, backend_str, ArrayT) test_typed_kernel_no_optimize(backend, backend_str, ArrayT) test_expr_kernel(backend, backend_str, ArrayT) + test_generated_kernel(backend, backend_str, ArrayT) return end