diff --git a/src/utils.jl b/src/utils.jl index ca3bea0..b31abca 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -28,16 +28,18 @@ function expand_pad_dims(pad_dims::Dims{N}) where {N} return ntuple(i -> isodd(i) ? 0 : pad_dims[i รท 2], 2N) end -function meshgrid(args::AbstractVector...) - return let N = length(args) - stack(enumerate(args)) do (i, arg) - new_shape = ones(Int, N) - new_shape[i] = length(arg) - repeat_sizes = collect(Int, map(length, args)) - repeat_sizes[i] = 1 - return repeat(Lux.Utils.contiguous(reshape(arg, new_shape...)), repeat_sizes...) - end +function meshgrid(args::AbstractVector{T}...) where {T} + N = length(args) + dims = map(length, args) + result = similar(first(args), dims..., N) + + for (i, arg) in enumerate(args) + new_shape = ntuple(j -> j == i ? dims[j] : 1, N) + view_result = selectdim(result, N + 1, i) + view_result .= reshape(arg, new_shape...) end + + return result end function decomposed_activation(f::F, x::Number) where {F}