Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/NeuralOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using Random: Random, AbstractRNG
using Lux: Lux, Chain, Dense, Conv, Parallel, NoOpLayer, WrappedFunction, Scale
using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxWrapperLayer
using LuxLib: fast_activation!!
using NNlib: NNlib, batched_mul, pad_constant, gelu
using NNlib: NNlib, batched_mul, gelu
using WeightInitializers: glorot_uniform

include("utils.jl")
Expand Down
7 changes: 3 additions & 4 deletions src/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,9 @@ function (conv::OperatorConv)(x::AbstractArray{T, N}, ps, st) where {T, N}
x_tr = truncate_modes(conv.tform, x_t)
x_p = apply_pattern(x_tr, ps.weight)

pad_dims = size(x_t)[1:(end - 2)] .- size(x_p)[1:(end - 2)]
x_padded = pad_constant(
x_p, expand_pad_dims(pad_dims), false; dims = ntuple(identity, ndims(x_p) - 2)
)
# Use type-stable zero-padding to restore to FFT size
target_sizes = ntuple(i -> size(x_t, i), Val(N - 2))
x_padded = pad_zeros_spatial(x_p, target_sizes)
out = inverse(conv.tform, x_padded, x)

return out, st
Expand Down
33 changes: 24 additions & 9 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,30 @@ function expand_pad_dims(pad_dims::Dims{N}) where {N}
return ntuple(i -> isodd(i) ? 0 : pad_dims[i ÷ 2], 2N)
end

function meshgrid(args::AbstractVector...)
return let N = length(args)
stack(enumerate(args)) do (i, arg)
new_shape = ones(Int, N)
new_shape[i] = length(arg)
repeat_sizes = collect(Int, map(length, args))
repeat_sizes[i] = 1
return repeat(Lux.Utils.contiguous(reshape(arg, new_shape...)), repeat_sizes...)
end
# Type-stable zero-padding for FFT operations.
# Pads an array with zeros in the first M spatial dimensions.
# The last two dimensions (channels, batch) are not padded.
# Type-stable zero-padding for FFT operations.
#
# Pad `x` with trailing zeros in its first `M` spatial dimensions so that each
# spatial extent matches `target_sizes`. The last two dimensions (channels,
# batch) are never padded.
#
# Returns `x` itself (no copy) when no padding is required.
#
# Throws `ArgumentError` when `length(target_sizes) != ndims(x) - 2`.
function pad_zeros_spatial(
        x::AbstractArray{T, N}, target_sizes::Dims{M}
) where {T, N, M}
    # Validate at the API boundary. `@assert` may be compiled away at higher
    # optimization levels, so throw explicitly instead.
    M == N - 2 || throw(ArgumentError(
        "target_sizes must have N-2 = $(N - 2) elements (spatial dimensions only), got $M"
    ))
    current_sizes = ntuple(i -> size(x, i), Val(M))
    # Tuple `==` is type-stable and allocation-free, unlike `all(a .== b)`
    # which materializes a broadcast.
    current_sizes == target_sizes && return x
    # Output keeps the unchanged channel/batch extents.
    out_size = (target_sizes..., size(x, N - 1), size(x, N))
    # `similar` preserves the concrete array type of `x` (e.g. GPU arrays),
    # where `zeros(T, ...)` would always produce a CPU `Array`.
    y = fill!(similar(x, out_size), zero(T))
    # Copy the input into the low corner of each spatial dimension; the
    # remainder stays zero.
    src_indices = (ntuple(i -> 1:size(x, i), Val(M))..., :, :)
    y[src_indices...] = x
    return y
end

# Build an (N+1)-dimensional coordinate grid from `N` coordinate vectors.
# The `i`-th input vector varies along dimension `i` and is repeated along all
# other dimensions; the `N` resulting arrays are stacked along a trailing
# dimension.
function meshgrid(args::Vararg{AbstractVector, N}) where {N}
    expanded = ntuple(Val(N)) do i
        axis = args[i]
        # Shape with `length(axis)` in slot `i` and singletons elsewhere.
        shape = ntuple(j -> j == i ? length(axis) : 1, Val(N))
        # Repeat along every dimension except the one `axis` already spans.
        reps = ntuple(j -> j == i ? 1 : length(args[j]), Val(N))
        repeat(Lux.Utils.contiguous(reshape(axis, shape)), reps...)
    end
    return stack(expanded)
end

Expand Down
4 changes: 4 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
[deps]
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
Expand All @@ -19,7 +21,9 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AllocCheck = "0.2"
Aqua = "0.8.7"
BenchmarkTools = "1"
Documenter = "1.5.0"
Enzyme = "0.13.48"
ExplicitImports = "1.9.0"
Expand Down
136 changes: 136 additions & 0 deletions test/perf_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Regression tests guarding type stability of the core padding/grid helpers and
# of `SpectralConv`'s forward pass. `Base.return_types` is used (rather than
# `@inferred`) so the asserted return supertype can be spelled out explicitly.
@testitem "Performance: Type Stability" setup = [SharedTestSetup] begin
    using InteractiveUtils
    using NeuralOperators
    using NeuralOperators: pad_zeros_spatial, expand_pad_dims, meshgrid

    # Test that key functions are type-stable (return concrete types, not Any)
    @testset "pad_zeros_spatial type stability" begin
        # 1D case: one spatial dim + (channel, batch); Dims{1} == Tuple{Int}
        x1d = rand(ComplexF32, 16, 5, 5)
        result_type = Base.return_types(pad_zeros_spatial, (typeof(x1d), Tuple{Int}))[1]
        @test result_type <: AbstractArray{ComplexF32, 3}

        # 2D case: two spatial dims + (channel, batch)
        x2d = rand(ComplexF32, 16, 16, 5, 5)
        result_type = Base.return_types(pad_zeros_spatial, (typeof(x2d), Tuple{Int, Int}))[1]
        @test result_type <: AbstractArray{ComplexF32, 4}
    end

    @testset "meshgrid type stability" begin
        r1 = range(0.0f0, 1.0f0; length = 32)
        r2 = range(0.0f0, 1.0f0; length = 32)

        # Two 32-point axes stack into a (32, 32, 2) Float32 array
        result_type = Base.return_types(meshgrid, (typeof(r1), typeof(r2)))[1]
        @test result_type <: AbstractArray{Float32, 3}
    end

    @testset "OperatorConv type stability" begin
        rng = StableRNG(12345)

        # 1D case — inference is checked through a closure so the full
        # (x, ps, st) call path is included.
        sc1d = SpectralConv(2 => 5, (16,))
        ps, st = Lux.setup(rng, sc1d)
        x = rand(Float32, 1024, 2, 5)

        result_type = Base.return_types((sc1d, x, ps, st) -> sc1d(x, ps, st),
            (typeof(sc1d), typeof(x), typeof(ps), typeof(st)))[1]
        @test result_type <: Tuple{AbstractArray{Float32, 3}, Any}

        # 2D case
        sc2d = SpectralConv(2 => 5, (16, 16))
        ps2d, st2d = Lux.setup(rng, sc2d)
        x2d = rand(Float32, 32, 32, 2, 5)

        result_type = Base.return_types((sc2d, x2d, ps2d, st2d) -> sc2d(x2d, ps2d, st2d),
            (typeof(sc2d), typeof(x2d), typeof(ps2d), typeof(st2d)))[1]
        @test result_type <: Tuple{AbstractArray{Float32, 4}, Any}
    end
end

# Allocation regression tests. Each model is called once to warm up (JIT
# compile) before measuring, so `@allocated` reflects steady-state allocations
# only. Bounds are baselines plus a margin to absorb CI variance; tightening
# them requires re-measuring the baselines.
@testitem "Performance: Allocation Bounds" setup = [SharedTestSetup] begin
    using BenchmarkTools
    using NeuralOperators

    @testset "SpectralConv allocation bounds" begin
        rng = StableRNG(12345)

        # 1D case - establish allocation bounds
        sc1d = SpectralConv(2 => 5, (16,))
        ps1d, st1d = Lux.setup(rng, sc1d)
        x1d = rand(Float32, 1024, 2, 5)

        # Warmup
        sc1d(x1d, ps1d, st1d)

        # Check allocations are within expected bounds
        # Current baseline: ~256KB, allow 20% margin
        allocs = @allocated sc1d(x1d, ps1d, st1d)
        @test allocs < 310_000 # 256KB + 20% margin

        # 2D case
        sc2d = SpectralConv(2 => 5, (16, 16))
        ps2d, st2d = Lux.setup(rng, sc2d)
        x2d = rand(Float32, 32, 32, 2, 5)

        # Warmup
        sc2d(x2d, ps2d, st2d)

        # Current baseline: ~598KB, allow 20% margin
        allocs = @allocated sc2d(x2d, ps2d, st2d)
        @test allocs < 720_000 # 600KB + 20% margin
    end

    @testset "FourierNeuralOperator allocation bounds" begin
        rng = StableRNG(12345)

        fno = FourierNeuralOperator(; chs = (2, 32, 32, 32, 32, 64, 1), modes = (16,))
        ps, st = Lux.setup(rng, fno)
        x = rand(Float32, 256, 2, 4)

        # Warmup
        fno(x, ps, st)

        # Current baseline: ~3.1MB, allow 50% margin for CI variance
        allocs = @allocated fno(x, ps, st)
        @test allocs < 5_000_000 # ~4.8MB max (generous bound for CI)
    end

    @testset "DeepONet allocation bounds" begin
        rng = StableRNG(12345)

        deeponet = DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16))
        ps, st = Lux.setup(rng, deeponet)
        u = rand(Float32, 64, 5)
        y = rand(Float32, 1, 10)

        # Warmup
        deeponet((u, y), ps, st)

        # Current baseline: ~3.7KB, allow 20% margin
        allocs = @allocated deeponet((u, y), ps, st)
        @test allocs < 5_000

        # Allocation count — @benchmark interpolates with `$` so setup cost
        # is excluded; only 3 samples are taken to keep the test fast.
        b = @benchmark $deeponet(($u, $y), $ps, $st) samples = 3 evals = 1
        @test b.allocs < 10
    end

    @testset "NOMAD allocation bounds" begin
        rng = StableRNG(12345)

        nomad = NOMAD(; approximator = (8, 32, 32, 16), decoder = (18, 16, 8, 8))
        ps, st = Lux.setup(rng, nomad)
        u = rand(Float32, 8, 5)
        y = rand(Float32, 2, 5)

        # Warmup
        nomad((u, y), ps, st)

        # Current baseline: ~3.2KB, allow 20% margin
        allocs = @allocated nomad((u, y), ps, st)
        @test allocs < 5_000

        # Allocation count
        b = @benchmark $nomad(($u, $y), $ps, $st) samples = 3 evals = 1
        @test b.allocs < 10
    end
end
Loading