Skip to content

support emitting a generated function#667

Open
vchuravy wants to merge 3 commits into release-0.9 from
vc/generated
Open

support emitting a generated function#667
vchuravy wants to merge 3 commits into release-0.9 from
vc/generated

Conversation

@vchuravy
Copy link
Member

@vchuravy vchuravy commented Dec 15, 2025

Motivated by #665

Proposed syntax

@kernel generated=true function f(::Val{N}) where N
           KernelAbstractions.Extras.@unroll $N for i in 1:10
           end
       end

Sadly, this doesn't quite work yet, since I still need to handle the `$N` interpolation correctly.

Currently:

    function cpu_f(__ctx__, ::Val{N}; ) where N
        if $(Expr(:generated))
            $(Expr(:copyast, :($(QuoteNode(:(let
      begin
          var"##N#249" = length((KernelAbstractions.__workitems_iterspace)(__ctx__))
          begin
              #= /home/vchuravy/src/KernelAbstractions/src/macros.jl:317 =#
              for var"##I#248" = (KernelAbstractions.__workitems_iterspace)(__ctx__)
                  #= /home/vchuravy/src/KernelAbstractions/src/macros.jl:318 =#
                  (KernelAbstractions.__validindex)(__ctx__, var"##I#248") || continue
                  #= /home/vchuravy/src/KernelAbstractions/src/macros.jl:319 =#
                  #= /home/vchuravy/src/KernelAbstractions/src/macros.jl:320 =#
                  #= REPL[19]:2 =# KernelAbstractions.Extra.@unroll $(Expr(:$, :N)) for i = 1:10
                          #= REPL[19]:2 =#
                          #= REPL[19]:3 =#
                      end
                  #= /home/vchuravy/src/KernelAbstractions/src/macros.jl:321 =#
              end
          end
      end
      return nothing
  end))))))
        else
            $(Expr(:meta, :generated_only))
        end
    end

Whereas

julia> @macroexpand @generated function(::Val{N}) where N
           quote
               KernelAbstractions.Extras.@unroll $N for i in 1:10
               end
           end
       end
:(function (::Val{N},) where N
      #= REPL[18]:1 =#
      if $(Expr(:generated))
          #= REPL[18]:1 =#
          #= REPL[18]:2 =#
          Core._expr(:block, $(QuoteNode(:(#= REPL[18]:3 =#))), Core._expr(:macrocall, $(Expr(:copyast, :($(QuoteNode(:(KernelAbstractions.Extras.var"@unroll")))))), $(QuoteNode(:(#= REPL[18]:3 =#))), N, $(Expr(:copyast, :($(QuoteNode(:(for i = 1:10
      #= REPL[18]:3 =#
      #= REPL[18]:4 =#
  end))))))))
      else
          $(Expr(:meta, :generated_only))
          return
      end
  end)

Note that the QuoteNode got broken into smaller pieces: the interpolated variable `N` is left alone, while the rest is passed to `Core._expr` and `QuoteNode`.

@github-actions
Copy link
Contributor

github-actions bot commented Dec 15, 2025

Benchmark Results

main d4ffa3b... main / d4ffa3b...
saxpy/default/Float32/1024 0.0728 ± 0.029 ms 0.627 ± 0.0079 μs 116 ± 46
saxpy/default/Float32/1048576 0.454 ± 0.022 ms 0.274 ± 0.02 ms 1.65 ± 0.14
saxpy/default/Float32/16384 0.0702 ± 0.03 ms 2.84 ± 0.31 μs 24.7 ± 11
saxpy/default/Float32/2048 0.0737 ± 0.03 ms 0.76 ± 0.059 μs 96.9 ± 40
saxpy/default/Float32/256 0.0721 ± 0.028 ms 0.565 ± 0.0053 μs 128 ± 49
saxpy/default/Float32/262144 0.167 ± 0.029 ms 0.0676 ± 0.0046 ms 2.47 ± 0.47
saxpy/default/Float32/32768 0.0713 ± 0.03 ms 5.49 ± 0.59 μs 13 ± 5.7
saxpy/default/Float32/4096 0.0755 ± 0.031 ms 1.13 ± 0.14 μs 66.6 ± 28
saxpy/default/Float32/512 0.0711 ± 0.029 ms 0.598 ± 0.0056 μs 119 ± 49
saxpy/default/Float32/64 0.0721 ± 0.029 ms 0.555 ± 0.0052 μs 130 ± 51
saxpy/default/Float32/65536 0.0938 ± 0.03 ms 13.8 ± 1.1 μs 6.77 ± 2.2
saxpy/default/Float64/1024 0.0689 ± 0.03 ms 0.746 ± 0.046 μs 92.3 ± 41
saxpy/default/Float64/1048576 0.615 ± 0.099 ms 0.595 ± 0.052 ms 1.03 ± 0.19
saxpy/default/Float64/16384 0.0773 ± 0.03 ms 5.45 ± 0.5 μs 14.2 ± 5.6
saxpy/default/Float64/2048 0.0697 ± 0.029 ms 1.17 ± 0.15 μs 59.7 ± 26
saxpy/default/Float64/256 0.0637 ± 0.031 ms 0.58 ± 0.0071 μs 110 ± 53
saxpy/default/Float64/262144 0.196 ± 0.034 ms 0.135 ± 0.013 ms 1.45 ± 0.29
saxpy/default/Float64/32768 0.0844 ± 0.029 ms 13.8 ± 1.1 μs 6.1 ± 2.1
saxpy/default/Float64/4096 0.071 ± 0.03 ms 1.73 ± 0.2 μs 41 ± 18
saxpy/default/Float64/512 0.0688 ± 0.03 ms 0.634 ± 0.0093 μs 108 ± 47
saxpy/default/Float64/64 0.0653 ± 0.031 ms 0.554 ± 0.0069 μs 118 ± 57
saxpy/default/Float64/65536 0.104 ± 0.03 ms 0.0334 ± 0.0024 ms 3.11 ± 0.92
saxpy/static workgroup=(1024,)/Float32/1024 0.0706 ± 0.029 ms 2.06 ± 0.026 μs 34.3 ± 14
saxpy/static workgroup=(1024,)/Float32/1048576 0.457 ± 0.023 ms 0.281 ± 0.022 ms 1.63 ± 0.15
saxpy/static workgroup=(1024,)/Float32/16384 0.0679 ± 0.029 ms 4.33 ± 0.34 μs 15.7 ± 6.8
saxpy/static workgroup=(1024,)/Float32/2048 0.0659 ± 0.029 ms 2.22 ± 0.083 μs 29.7 ± 13
saxpy/static workgroup=(1024,)/Float32/256 0.0696 ± 0.029 ms 2.55 ± 0.022 μs 27.3 ± 11
saxpy/static workgroup=(1024,)/Float32/262144 0.165 ± 0.028 ms 0.0693 ± 0.013 ms 2.39 ± 0.6
saxpy/static workgroup=(1024,)/Float32/32768 0.0759 ± 0.03 ms 7.67 ± 0.64 μs 9.89 ± 4.1
saxpy/static workgroup=(1024,)/Float32/4096 0.0725 ± 0.031 ms 2.56 ± 0.13 μs 28.3 ± 12
saxpy/static workgroup=(1024,)/Float32/512 0.0704 ± 0.029 ms 2.7 ± 0.027 μs 26.1 ± 11
saxpy/static workgroup=(1024,)/Float32/64 0.0683 ± 0.03 ms 2.45 ± 0.021 μs 27.9 ± 12
saxpy/static workgroup=(1024,)/Float32/65536 0.0897 ± 0.031 ms 17 ± 2.3 μs 5.29 ± 2
saxpy/static workgroup=(1024,)/Float64/1024 0.0681 ± 0.029 ms 2.23 ± 0.075 μs 30.6 ± 13
saxpy/static workgroup=(1024,)/Float64/1048576 0.576 ± 0.077 ms 0.649 ± 0.11 ms 0.887 ± 0.19
saxpy/static workgroup=(1024,)/Float64/16384 0.0754 ± 0.029 ms 7.64 ± 0.66 μs 9.87 ± 3.9
saxpy/static workgroup=(1024,)/Float64/2048 0.0705 ± 0.03 ms 2.56 ± 0.13 μs 27.6 ± 12
saxpy/static workgroup=(1024,)/Float64/256 0.0656 ± 0.029 ms 2.55 ± 0.026 μs 25.7 ± 11
saxpy/static workgroup=(1024,)/Float64/262144 0.197 ± 0.031 ms 0.144 ± 0.025 ms 1.37 ± 0.32
saxpy/static workgroup=(1024,)/Float64/32768 0.0818 ± 0.028 ms 17.1 ± 1.8 μs 4.8 ± 1.7
saxpy/static workgroup=(1024,)/Float64/4096 0.071 ± 0.03 ms 3.11 ± 0.21 μs 22.8 ± 9.8
saxpy/static workgroup=(1024,)/Float64/512 0.0695 ± 0.029 ms 2.71 ± 0.031 μs 25.6 ± 11
saxpy/static workgroup=(1024,)/Float64/64 0.0682 ± 0.03 ms 2.43 ± 0.021 μs 28.1 ± 12
saxpy/static workgroup=(1024,)/Float64/65536 0.102 ± 0.031 ms 0.0362 ± 0.0054 ms 2.81 ± 0.95
time_to_load 1.01 ± 0.036 s 0.298 ± 0.0018 s 3.4 ± 0.12

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact to the workflow run for this PR.
Go to "Actions"->"Benchmark a pull request"->[the most recent run]->"Artifacts" (at the bottom).

@codecov
Copy link

codecov bot commented Dec 15, 2025

Codecov Report

❌ Patch coverage is 46.15385% with 7 lines in your changes missing coverage. Please review.
✅ Project coverage is 71.36%. Comparing base (1f84b17) to head (9195d94).

Files with missing lines Patch % Lines
src/macros.jl 42.85% 4 Missing ⚠️
src/KernelAbstractions.jl 50.00% 3 Missing ⚠️
Additional details and impacted files
@@               Coverage Diff               @@
##           release-0.9     #667      +/-   ##
===============================================
- Coverage        71.85%   71.36%   -0.49%     
===============================================
  Files               14       14              
  Lines              906      915       +9     
===============================================
+ Hits               651      653       +2     
- Misses             255      262       +7     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@vchuravy
Copy link
Member Author

julia> @kernel generated=false function g(::Val{N}) where N
                  KernelAbstractions.Extras.@unroll $N for i in 1:10
                  end
              end
ERROR: LoadError: Syntax error: `@unroll N expr` needs a constant integer N
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] var"@unroll"(__source__::LineNumberNode, __module__::Module, N::Any, expr::Any)
   @ KernelAbstractions.Extras.LoopInfo ~/src/KernelAbstractions/src/extras/loopinfo.jl:60
in expression starting at REPL[5]:2
julia> @kernel generated=true function f(::Val{N}) where N
                  KernelAbstractions.Extras.@unroll $N for i in 1:10
                  end
              end
 @ka_code_llvm optimize=false raw=true dump_module = true f(CPU())(Val(17), ndrange=1)
; ...
L33:                                              ; preds = %L29
; │ @ none within `macro expansion` @ /home/vchuravy/src/KernelAbstractions/src/macros.jl:321
; │┌ @ multidimensional.jl:417 within `iterate`
    br label %L22, !dbg !87, !llvm.loop !100
; ...
!100 = distinct !{!100, !101}
!101 = !{!"llvm.loop.unroll.count", i64 17}

Seems to work!

@vchuravy vchuravy marked this pull request as ready for review December 16, 2025 10:00
@vchuravy vchuravy requested a review from gbaraldi December 16, 2025 10:02
@vchuravy vchuravy added enhancement New feature or request needs test labels Dec 16, 2025
@vchuravy
Copy link
Member Author

vchuravy commented Dec 21, 2025

Homework for myself:

Write the kernel below with @nif


"""
    f(x::Integer) -> Int
    f(x::AbstractFloat) -> Int

Tag `x` by its numeric kind: return `1` for any `Integer` and `2` for any
`AbstractFloat`. The value of `x` itself is never inspected.
"""
function f(x::Integer)
    return 1
end

function f(x::AbstractFloat)
    return 2
end

# For each index `idx`, store `f(tpl[k])` into `A[idx]`, where `tpl = B[idx]` and `k`
# is 2 when `tpl[1] > 0` and 1 otherwise. The value is produced by building a whole
# tuple with `ntuple` and then indexing that tuple with the runtime value `k`.
@kernel function _kern!(A::AbstractArray, B::AbstractArray)
    # Index obtained from `@index(Global, Cartesian)`; it addresses both `A` and `B`.
    idx = @index(Global, Cartesian) 
    tpl = B[idx]
    # Slot selector, decided at run time from the first entry of `tpl`.
    k = tpl[1] > 0 ? 2 : 1
    # Entry `j` of the tuple is `f(tpl[j])` when `j == k`; the `if` has no `else`, so
    # every other entry is `nothing`. The trailing `[k]` then picks out entry `k`.
    A[idx] = ntuple(Val(length(tpl))) do j
        if j == k
            f(tpl[j])
        end
    end[k]
    nothing
end

"""
    wtf(N=64)

Driver for `_kern!`: build an `N`×`N` array of zeros and an `N`×`N` array of
`(randn value, 1)` tuples, pass both through `cu`, and launch `_kern!` on the
backend of the first array with `ndrange` equal to its size. Returns whatever the
kernel launch returns.

NOTE(review): `cu` is presumably CUDA.jl's device-array conversion — confirm.
"""
function wtf(N=64)
    dest = cu(zeros(N, N))
    src = cu(tuple.(randn(N, N), ones(Int, N, N)))
    launch = _kern!(get_backend(dest))
    return launch(dest, src; ndrange=size(dest))
end

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request needs test

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant