Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
20c3e78
simplified & improved cuFFT bindings for fft(), selecting for least e…
RainerHeintzmann Mar 4, 2026
00e7cfb
towards arbitrary dimensions
RainerHeintzmann Mar 11, 2026
34bcc57
Merge branch 'JuliaGPU:master' into master
RainerHeintzmann Mar 11, 2026
cc9a8f1
cleanup. finished tests
RainerHeintzmann Mar 11, 2026
1f12bd0
formatted code
RainerHeintzmann Mar 12, 2026
9b25983
updated Project.toml to current version
RainerHeintzmann Mar 12, 2026
c5d7c97
Merge branch 'JuliaGPU:master' into master
RainerHeintzmann Mar 12, 2026
e75e627
Merge branch 'JuliaGPU:master' into master
RainerHeintzmann Mar 13, 2026
74ba87a
[skip tests][skip benchmarks] reverted LocalPreferences.toml
RainerHeintzmann Mar 13, 2026
08d11c1
Merge branch 'JuliaGPU:master' into master
RainerHeintzmann Mar 13, 2026
b2f46fd
[skip_tests][skip_benchmarks] removed sort on Tuples
RainerHeintzmann Mar 13, 2026
f53bb76
Merge branch 'master' of https://github.com/RainerHeintzmann/CUDA.jl
RainerHeintzmann Mar 13, 2026
e9a9d38
Merge branch 'JuliaGPU:master' into master
RainerHeintzmann Mar 16, 2026
4f0148b
Merge branch 'JuliaGPU:master' into master
RainerHeintzmann Mar 17, 2026
166336a
Merge branch 'JuliaGPU:master' into master
RainerHeintzmann Mar 24, 2026
a70100c
Merge branch 'master' into RainerHeintzmann/master
maleadt Apr 9, 2026
bf0d277
bug fixes according to https://github.com/JuliaGPU/CUDA.jl/pull/3052#…
RainerHeintzmann Apr 10, 2026
9232a41
minor corrections
RainerHeintzmann Apr 15, 2026
c2b2b2e
changed CUDA to CUDACore in test
RainerHeintzmann Apr 15, 2026
337e84e
changed naming from `ensure_raising` to `ensure_increasing`
RainerHeintzmann Apr 15, 2026
31c0271
also chaned `ensure_increasing` naming in the tests
RainerHeintzmann Apr 15, 2026
c101ea5
Merge branch 'master' into master
RainerHeintzmann Apr 15, 2026
dae1fdf
ensure_increasing to not allocate (https://github.com/JuliaGPU/CUDA.j…
RainerHeintzmann Apr 16, 2026
ae48012
cuFFT: style and typo cleanup in cufftXtMakePlanMany rewrite
maleadt Apr 17, 2026
6bb6b44
cuFFT: drop unreachable ensure_increasing(::Number); refine fallback …
maleadt Apr 17, 2026
c3ac4c6
cuFFT: pad inembed[0]/cnembed[0] to satisfy cuFFT's >= n[0] precondition
maleadt Apr 17, 2026
ec2f224
cuFFT: explicit duplicate-dim rejection; remove silent unique()
maleadt Apr 17, 2026
e6bc207
cuFFT: enforce strictly-increasing region in plan_rfft / plan_brfft
maleadt Apr 17, 2026
3afb999
cuFFT: replace circshift! workaround with out-of-place scratch approach
maleadt Apr 17, 2026
0ec5ea8
cuFFT: surface half-precision power-of-2 violations as ArgumentError
maleadt Apr 17, 2026
b028fe8
cuFFT: restore (2,4) test coverage in 4D complex batched FFT
maleadt Apr 17, 2026
04b7aa1
cuFFT: document plan-cache key trade-off
maleadt Apr 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 183 additions & 53 deletions lib/cufft/src/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,25 +130,80 @@ function irfft(x::DenseCuArray{<:Union{Real,Integer,Rational}}, d::Integer, regi
irfft(complexfloat(x), d, region)
end

# yields the maximal dimensions of the plan, for plans starting at dim 1 or ending at the size vector,
# this is always the full input size
function plan_max_dims(region, sz)
if (region[1] == 1 && (length(region) <=1 || all(diff(collect(region)) .== 1)))
return length(sz)
else
return region[end]
"""
get_batch_dims(region, sz)

Return the dimensions over which cuFFT runs its internal batching and the dimensions that are handled by external (for-loop) batching.
The run of consecutive non-transformed dimensions with the largest product of sizes is used as the internal batch dimensions (the first such run wins ties).
All other non-transformed dimensions are external batch dimensions.

internal_batch_dims, external_batch_dims = get_batch_dims(region, sz)

# Parameters:
- `region`: Tuple of dimensions to transform
- `sz`: size of the array to transform. All dimensions not in `region` are considered as batch dimensions.
This size Tuple is only used to determine the best set of consecutive dimensions to be used for internal batching.
"""
function get_batch_dims(region, sz)
    internal_batch_dims = ()
    external_batch_dims = ()
    best_gap_size = 0
    gap_start = 1
    # Walk the transform dims plus one sentinel past the last dim, so that the
    # run of non-transformed dims after the final transform dim is visited too.
    for t in (region..., length(sz) + 1)
        gap = gap_start:(t - 1)
        if !isempty(gap)
            # number of elements spanned by this run of non-transformed dims
            gap_size = prod(sz[gap])
            if gap_size > best_gap_size
                # A strictly larger run was found: demote the previous winner
                # to external batching and promote this run.
                external_batch_dims = (external_batch_dims..., internal_batch_dims...)
                internal_batch_dims = Tuple(gap)
                best_gap_size = gap_size
            else
                external_batch_dims = (external_batch_dims..., gap...)
            end
        end
        gap_start = t + 1
    end
    return internal_batch_dims, external_batch_dims
end

# retrieves the size to allocate even if the trailing dimensions are not transformed
# Size to allocate for `x`: the leading entries are taken from `osz`; any
# dimension of `x` beyond `length(osz)` keeps its extent from `size(x)`.
get_osz(osz, x) = ntuple(d -> d <= length(osz) ? osz[d] : size(x, d), ndims(x))

# returns a view of the front part of the dimensions of the array up to md dimensions
function front_view(X, md)
t = ntuple((d)->ifelse(d<=md, Colon(), 1), ndims(X))
@view X[t...]
# Return `s` unchanged when all region dimensions are pairwise distinct;
# throw an `ArgumentError` as soon as a repeated dimension is found.
function ensure_unique(s::NTuple{N, Int}) where N
    for hi in 2:N
        for lo in 1:(hi - 1)
            if s[lo] == s[hi]
                throw(ArgumentError(
                    "FFT region dimensions must be unique; got $s"))
            end
        end
    end
    return s
end

# Real-to-complex / complex-to-real plans must not reorder `region`: cuFFT
# halves region[1], and AbstractFFTs sizes the output from `first(region)`
# (definitions.jl:343,349). The caller therefore has to pass `region` already
# in strictly increasing order; anything else (including repeated dimensions)
# is rejected with an `ArgumentError`. Returns `s` unchanged on success.
function ensure_strictly_increasing(s::NTuple{N, Int}) where N
    for i in 2:N
        if !(s[i - 1] < s[i])
            throw(ArgumentError(
                "for rfft/brfft, region must be in strictly increasing order " *
                "(its first element is the dimension reduced from N to N÷2+1); got $s"))
        end
    end
    return s
end

# Return the transform dimensions `s` sorted in increasing order.
#
# Sort on tuples is only implemented as of Julia 1.12, and cuFFT supports at most
# three transform dimensions per plan, so we hand-code the cases here.
# Any other length throws an `ArgumentError`: an empty region gets its own
# message (the generic "at most 3 ... got 0" text would be misleading), and
# more than three dimensions reports the cuFFT limit.
function ensure_increasing(s::Tuple{})
    throw(ArgumentError("FFT region must contain at least one transform dimension; got $s"))
end
ensure_increasing(s::NTuple{1, Int}) = s
ensure_increasing(s::NTuple{2, Int}) = s[1] > s[2] ? (s[2], s[1]) : s
function ensure_increasing(s::NTuple{3, Int})
    # three-element sorting network: (1,2), (2,3), (1,2)
    a, b, c = s
    if a > b
        a, b = b, a
    end
    if b > c
        b, c = c, b
    end
    if a > b
        a, b = b, a
    end
    return (a, b, c)
end
function ensure_increasing(s::NTuple{N, Int}) where N
    throw(ArgumentError("cuFFT supports at most 3 transform dimensions per plan; got $N"))
end
# region is an iterable subset of dimensions
# spec. an integer, range, tuple, or array

Expand All @@ -169,21 +224,19 @@ end
# In-place forward complex FFT plan for `X` over the dimensions in `region`.
# `region` is validated (no repeated dimensions) and sorted into increasing
# order before the plan is requested. The plan is requested for the full
# `size(X)`; the earlier truncation of the size via `plan_max_dims` (and the
# extra, immediately discarded plan handle it produced) is gone.
function plan_fft!(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R}
    K = CUFFT_FORWARD
    inplace = true
    region = ensure_increasing(ensure_unique(region))

    handle = cufftGetPlan(T, T, size(X), region)

    CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing)
end

# In-place backward (unnormalized inverse) complex FFT plan for `X` over the
# dimensions in `region`. `region` is validated (no repeated dimensions) and
# sorted into increasing order before the plan is requested. The plan is
# requested for the full `size(X)`; the earlier truncation of the size via
# `plan_max_dims` (and the extra, immediately discarded plan handle it
# produced) is gone.
function plan_bfft!(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R}
    K = CUFFT_INVERSE
    inplace = true
    region = ensure_increasing(ensure_unique(region))

    handle = cufftGetPlan(T, T, size(X), region)

    CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing)
end
Expand All @@ -192,21 +245,19 @@ end
# Out-of-place forward complex FFT plan for `X` over the dimensions in
# `region`. `region` is validated (no repeated dimensions) and sorted into
# increasing order before the plan is requested. The plan is requested for the
# full `size(X)`; the earlier truncation of the size via `plan_max_dims` (and
# the extra, immediately discarded plan handle it produced) is gone.
function plan_fft(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R}
    K = CUFFT_FORWARD
    inplace = false
    region = ensure_increasing(ensure_unique(region))

    handle = cufftGetPlan(T, T, size(X), region)

    CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing)
end

# Out-of-place backward (unnormalized inverse) complex FFT plan for `X` over
# the dimensions in `region`. `region` is validated (no repeated dimensions)
# and sorted into increasing order before the plan is requested. The plan is
# requested for the full `size(X)`; the earlier truncation of the size via
# `plan_max_dims` (and the extra, immediately discarded plan handle it
# produced) is gone.
function plan_bfft(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R}
    K = CUFFT_INVERSE
    inplace = false
    region = ensure_increasing(ensure_unique(region))

    handle = cufftGetPlan(T, T, size(X), region)

    CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, size(X), size(X), region, nothing)
end
Expand All @@ -221,11 +272,9 @@ end
function plan_rfft(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftReals,N,R}
K = CUFFT_FORWARD
inplace = false
region = ensure_strictly_increasing(region)

md = plan_max_dims(region,size(X))
sizex = size(X)[1:md]

handle = cufftGetPlan(complex(T), T, sizex, region)
handle = cufftGetPlan(complex(T), T, size(X), region)

xdims = size(X)
ydims = Base.setindex(xdims, div(xdims[region[1]], 2) + 1, region[1])
Expand All @@ -248,6 +297,7 @@ end
function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R}
K = CUFFT_INVERSE
inplace = false
region = ensure_strictly_increasing(region)

xdims = size(X)
ydims = Base.setindex(xdims, d, region[1])
Expand All @@ -265,18 +315,14 @@ end

# Inverse of a backward plan `p`: a forward plan with the element types and
# the input/output sizes of `p` swapped, wrapped in a `ScaledPlan` carrying
# `normalization(real(T), p.output_size, p.region)`.
# The plan is requested for the full `p.output_size` (the size the new plan
# reads); the earlier truncation via `plan_max_dims` and the extra plan
# handle it created are gone.
function plan_inv(p::CuFFTPlan{T,S,CUFFT_INVERSE,inplace,N,R,B}
                  ) where {T<:cufftNumber,S<:cufftNumber,inplace,N,R,B}
    handle = cufftGetPlan(S, T, p.output_size, p.region)
    ScaledPlan(CuFFTPlan{S,T,CUFFT_FORWARD,inplace,N,R,B}(handle, p.output_size, p.input_size, p.region, p.buffer),
               normalization(real(T), p.output_size, p.region))
end

# Inverse of a forward plan `p`: a backward plan with the element types of `p`
# swapped, wrapped in a `ScaledPlan` carrying
# `normalization(real(S), p.input_size, p.region)`.
# The plan is requested for the full `p.input_size`; the earlier truncation
# via `plan_max_dims` and the extra plan handle it created are gone.
function plan_inv(p::CuFFTPlan{T,S,CUFFT_FORWARD,inplace,N,R,B}
                  ) where {T<:cufftNumber,S<:cufftNumber,inplace,N,R,B}
    handle = cufftGetPlan(S, T, p.input_size, p.region)
    ScaledPlan(CuFFTPlan{S,T,CUFFT_INVERSE,inplace,N,R,B}(handle, p.output_size, p.input_size, p.region, p.buffer),
               normalization(real(S), p.input_size, p.region))
end
Expand Down Expand Up @@ -313,24 +359,108 @@ function unsafe_execute!(plan::CuFFTPlan{T,S,K,inplace}, x::DenseCuArray{S}, y::
cufftXtExec(plan, x, y, K)
end

# a version of unsafe_execute which applies the plan to each element of trailing dimensions not covered by the plan.
# Note that for plans, with trailing non-transform dimensions views are created for each of such elements.
# Such views each have lower dimensions and are then transformed by the lower dimension low-level Cuda plan.
function unsafe_execute_trailing!(p, x, y)
N = plan_max_dims(p.region, p.output_size)
M = ndims(x)
d = p.region[end]
if M == N
unsafe_execute!(p,x,y)
else
front_ids = ntuple((dd)->Colon(), d)
for c in CartesianIndices(size(x)[d+1:end])
ids = ntuple((dd)->c[dd], M-N)
vx = @view x[front_ids..., ids...]
vy = @view y[front_ids..., ids...]
unsafe_execute!(p,vx,vy)
# Footprint of a single cuFFT execution, measured in elements: the largest
# 0-based linear offset it touches, plus one. One execution spans every index
# combination of the dims in `region` and `internal_batch_dims`; external
# batch dims contribute nothing because they stay fixed at the current index.
# With no covered dims at all the footprint is a single element.
function plan_footprint(sizes::Dims, region, internal_batch_dims)
    covered = (region..., internal_batch_dims...)
    # Column-major layout: the stride of dim d is the product of all lower
    # extents, and the last index along d sits (sizes[d] - 1) strides in.
    last_offset = sum(d -> (sizes[d] - 1) * prod(sizes[1:(d - 1)]), covered; init = 0)
    return last_offset + 1
end

# Copy `src` into `dest` rotated left by `shift` elements, i.e.
# dest[i] = src[mod(i - 1 + shift, n) + 1] for i in 1..n with n = length(src).
# `src` is only read; `dest` is returned.
function shift_copy!(dest::DenseCuArray, src::DenseCuArray, shift::Integer)
    n = length(src)
    k = mod(shift, n)
    if k == 0
        # nothing to rotate: plain copy
        return unsafe_copyto!(dest, 1, src, 1, n)
    end
    tail_len = n - k
    # src[k+1:n] becomes the front of dest ...
    unsafe_copyto!(dest, 1, src, k + 1, tail_len)
    # ... and the wrapped-around src[1:k] becomes its back.
    unsafe_copyto!(dest, tail_len + 1, src, 1, k)
    return dest
end

# A version of unsafe_execute! that handles external batch dims (dims outside
# the cuFFT plan's internal batching) by issuing one cuFFT call per external
# index. cuFFT's R2C/C2R APIs require each call's base address to be
# Complex-aligned; in 1-based Julia indices that means the batch's linear
# start (`bs`) must be odd when its side of the transform has Real eltype.
#
# Two strategies are used, picked by which side is Real:
#
#   * R2C (input Real, output Complex): rotate x left by one element into a
#     fresh aligned `scratch_x` once and read misaligned batches from it. The
#     user's input is never mutated and the extra cost is O(length(x))
#     regardless of how many batches are misaligned.
#
#   * C2R (input Complex, output Real): per-misaligned-batch
#     read-modify-write through a small `scratch_y` of size `foot_y`. Other
#     external batches share the same footprint range in y, so we seed
#     scratch_y with y's current contents before each cuFFT call and copy
#     back after, preserving their writes.
#
# Arguments: `p` is the plan, `x` the input array and `y` the output array
# (the same array for in-place plans). Returns `nothing`.
function unsafe_execute_external_batches!(p::CuFFTPlan{T,S,K,inplace}, x, y) where {T,S,K,inplace}
    region = p.region
    internal_dims, external_dims = get_batch_dims(region, p.output_size)
    if isempty(external_dims)
        # everything is covered by the plan itself: a single cuFFT call suffices
        unsafe_execute!(p, x, y)
        return
    end

    # Column-major element strides of every dim (entry d is the stride of dim d).
    prefix_prod_x = (1, cumprod(size(x))...)
    prefix_prod_y = (1, cumprod(size(y))...)
    ext_stride_x = map(d -> prefix_prod_x[d], external_dims)
    ext_stride_y = map(d -> prefix_prod_y[d], external_dims)
    ext_size = map(d -> size(x, d), external_dims)
    ci = CartesianIndices(ext_size)
    foot_x = plan_footprint(size(x), region, internal_dims)
    foot_y = plan_footprint(size(y), region, internal_dims)

    # R2C side: if any external batch is misaligned, pre-rotate x by exactly
    # one element into scratch_x. A misaligned start `bs` is even (>= 2), so
    # `bs - 1` is odd and >= 1 for every misaligned batch.
    # The offset must not be derived from the first misaligned batch in
    # iteration order: `external_dims` is not necessarily sorted by stride, so
    # a later batch can have a smaller start and `bs - to_skip_x` would fall
    # below 1.
    to_skip_x = 0
    scratch_x = nothing
    if S <: Real
        for c in ci
            bs = sum(ext_stride_x .* (Tuple(c) .- 1)) + 1
            if iseven(bs)
                to_skip_x = 1
                break
            end
        end
        if to_skip_x > 0
            scratch_x = shift_copy!(CuArray{S}(undef, length(x)), x, to_skip_x)
        end
    end

    # C2R side: per-misaligned-batch scratch, allocated lazily.
    scratch_y = nothing

    for c in ci
        # 1-based linear start of this external batch in x and in y
        bs_x = sum(ext_stride_x .* (Tuple(c) .- 1)) + 1
        bs_y = sum(ext_stride_y .* (Tuple(c) .- 1)) + 1
        misaligned_x = S <: Real && iseven(bs_x)
        misaligned_y = T <: Real && iseven(bs_y)

        vx = misaligned_x ?
            (@view scratch_x[bs_x - to_skip_x : end]) :
            (@view x[bs_x : end])
        vy = if misaligned_y
            scratch_y === nothing && (scratch_y = CuArray{T}(undef, foot_y))
            # seed with y's current contents so interleaved results of other
            # external batches survive the copy-back below
            unsafe_copyto!(scratch_y, 1, y, bs_y, foot_y)
            @view scratch_y[1:foot_y]
        else
            @view y[bs_y : end]
        end

        unsafe_execute!(p, vx, vy)

        if misaligned_y
            unsafe_copyto!(y, bs_y, scratch_y, 1, foot_y)
        end
    end
    return
end

## high-level integrations
Expand All @@ -346,13 +476,13 @@ function LinearAlgebra.mul!(y::DenseCuArray{T}, p::CuFFTPlan{T,S,K,inplace}, x::
else
z = x
end
unsafe_execute_trailing!(p, z, y)
unsafe_execute_external_batches!(p, z, y)
y
end

# Apply the in-place plan `p` to `x`: after `assert_applicable(p, x)` the
# transform is executed exactly once with `x` as both input and output, and
# `x` is returned. (The former `unsafe_execute_trailing!` call is dropped;
# keeping it next to the new call would transform `x` twice.)
function Base.:(*)(p::CuFFTPlan{T,S,K,true}, x::DenseCuArray{S}) where {T,S,K}
    assert_applicable(p, x)
    unsafe_execute_external_batches!(p, x, x)
    x
end

Expand All @@ -372,6 +502,6 @@ function Base.:(*)(p::CuFFTPlan{T,S,K,false}, x::DenseCuArray{S1,M}) where {T,S,
end
assert_applicable(p, z)
y = CuArray{T,M}(undef, p.output_size)
unsafe_execute_trailing!(p, z, y)
unsafe_execute_external_batches!(p, z, y)
y
end
Loading