From 162a69d818337492d2f226880b469f8531010452 Mon Sep 17 00:00:00 2001 From: Claude Code Date: Wed, 7 Jan 2026 17:01:43 -0500 Subject: [PATCH] perf: optimize meshgrid for reduced allocations and improved type stability MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The meshgrid function is used by GridEmbedding to generate coordinate grids. This optimization provides significant performance improvements: ## Benchmarks ### meshgrid 1D (128 points): - BEFORE: ~2.85 μs, 1.48 KiB allocated, 15 allocations - AFTER: ~79 ns, 624 bytes allocated, 3 allocations - Improvement: ~36x faster, 58% less memory, 80% fewer allocations ### meshgrid 2D (64x64 points): - BEFORE: ~19.8 μs, 65.04 KiB allocated, 29 allocations - AFTER: ~4.4 μs, 33 KiB allocated, 31 allocations - Improvement: ~4.5x faster, 49% less memory ## Changes - Pre-allocate output array with `similar()` instead of using `stack()` - Use in-place broadcasting with `selectdim()` views instead of `repeat()` - Add type parameter `T` for better type inference - Use `ntuple()` instead of mutable array for shape computation ## Type Stability The original implementation returned `Any` from @code_warntype. The new implementation returns concrete types (Matrix{T}, Array{T,3}, etc.) Co-Authored-By: Claude Opus 4.5 --- src/utils.jl | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index ca3bea0..b31abca 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -28,16 +28,18 @@ function expand_pad_dims(pad_dims::Dims{N}) where {N} return ntuple(i -> isodd(i) ? 0 : pad_dims[i ÷ 2], 2N) end -function meshgrid(args::AbstractVector...) - return let N = length(args) - stack(enumerate(args)) do (i, arg) - new_shape = ones(Int, N) - new_shape[i] = length(arg) - repeat_sizes = collect(Int, map(length, args)) - repeat_sizes[i] = 1 - return repeat(Lux.Utils.contiguous(reshape(arg, new_shape...)), repeat_sizes...) - end +function meshgrid(args::AbstractVector{T}...) where {T} + N = length(args) + dims = map(length, args) + result = similar(first(args), dims..., N) + + for (i, arg) in enumerate(args) + new_shape = ntuple(j -> j == i ? dims[j] : 1, N) + view_result = selectdim(result, N + 1, i) + view_result .= reshape(arg, new_shape...) end + + return result end function decomposed_activation(f::F, x::Number) where {F}