Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .JuliaFormatter.toml

This file was deleted.

38 changes: 14 additions & 24 deletions .github/workflows/FormatPR.yml
Original file line number Diff line number Diff line change
@@ -1,29 +1,19 @@
name: FormatPR
name: format-check

on:
schedule:
- cron: '0 0 * * *'
push:
branches:
- 'master'
- 'main'
- 'release-'
tags: '*'
pull_request:

jobs:
build:
runic:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- name: Install JuliaFormatter and format
run: |
julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter", version="1"))'
julia -e 'using JuliaFormatter; format(".")'
# https://github.com/marketplace/actions/create-pull-request
# https://github.com/peter-evans/create-pull-request#reference-example
- name: Create Pull Request
id: cpr
uses: peter-evans/create-pull-request@v7
- uses: actions/checkout@v4
- uses: fredrikekre/runic-action@v1
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: Format .jl files
title: 'Automatic JuliaFormatter.jl run'
branch: auto-juliaformatter-pr
delete-branch: true
labels: formatting, automated pr, no changelog
- name: Check outputs
run: |
echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}"
echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}"
version: '1'
24 changes: 12 additions & 12 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
using Documenter, NeuralOperators

cp("./docs/Manifest.toml", "./docs/src/assets/Manifest.toml"; force=true)
cp("./docs/Project.toml", "./docs/src/assets/Project.toml"; force=true)
cp("./docs/Manifest.toml", "./docs/src/assets/Manifest.toml"; force = true)
cp("./docs/Project.toml", "./docs/src/assets/Project.toml"; force = true)

ENV["GKSwstype"] = "100"
ENV["DATADEPS_ALWAYS_ACCEPT"] = true

include("pages.jl")

makedocs(;
sitename="NeuralOperators.jl",
clean=true,
doctest=false,
linkcheck=true,
modules=[NeuralOperators],
format=Documenter.HTML(;
prettyurls=get(ENV, "CI", "false") == "true",
canonical="https://docs.sciml.ai/NeuralOperators/stable/",
assets=["assets/favicon.ico"],
sitename = "NeuralOperators.jl",
clean = true,
doctest = false,
linkcheck = true,
modules = [NeuralOperators],
format = Documenter.HTML(;
prettyurls = get(ENV, "CI", "false") == "true",
canonical = "https://docs.sciml.ai/NeuralOperators/stable/",
assets = ["assets/favicon.ico"],
),
pages,
)

deploydocs(; repo="github.com/SciML/NeuralOperators.jl.git", push_preview=true)
deploydocs(; repo = "github.com/SciML/NeuralOperators.jl.git", push_preview = true)
76 changes: 39 additions & 37 deletions src/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ function LuxCore.initialparameters(rng::AbstractRNG, layer::OperatorConv)
in_chs, out_chs = layer.in_chs, layer.out_chs
scale = real(one(eltype(layer.tform))) / (in_chs * out_chs)
return (;
weight=scale * layer.init_weight(
weight = scale * layer.init_weight(
rng, eltype(layer.tform), out_chs, in_chs, layer.prod_modes
)
),
)
end

Expand All @@ -52,22 +52,22 @@ function LuxCore.parameterlength(layer::OperatorConv)
end

function OperatorConv(
ch::Pair{<:Integer,<:Integer},
modes::Dims,
tform::AbstractTransform;
init_weight=glorot_uniform,
)
ch::Pair{<:Integer, <:Integer},
modes::Dims,
tform::AbstractTransform;
init_weight = glorot_uniform,
)
return OperatorConv(ch..., prod(modes), tform, init_weight)
end

function (conv::OperatorConv)(x::AbstractArray{T,N}, ps, st) where {T,N}
function (conv::OperatorConv)(x::AbstractArray{T, N}, ps, st) where {T, N}
x_t = transform(conv.tform, x)
x_tr = truncate_modes(conv.tform, x_t)
x_p = apply_pattern(x_tr, ps.weight)

pad_dims = size(x_t)[1:(end - 2)] .- size(x_p)[1:(end - 2)]
x_padded = pad_constant(
x_p, expand_pad_dims(pad_dims), false; dims=ntuple(identity, ndims(x_p) - 2)
x_p, expand_pad_dims(pad_dims), false; dims = ntuple(identity, ndims(x_p) - 2)
)
out = inverse(conv.tform, x_padded, x)

Expand All @@ -88,8 +88,8 @@ julia> SpectralConv(2 => 5, (16,));
```
"""
function SpectralConv(
ch::Pair{<:Integer,<:Integer}, modes::Dims; shift::Bool=false, kwargs...
)
ch::Pair{<:Integer, <:Integer}, modes::Dims; shift::Bool = false, kwargs...
)
return OperatorConv(ch, modes, FourierTransform{ComplexF32}(modes, shift); kwargs...)
end

Expand Down Expand Up @@ -121,18 +121,18 @@ julia> OperatorKernel(2 => 5, (16,), FourierTransform{ComplexF64}((16,)));
end

function OperatorKernel(
ch::Pair{<:Integer,<:Integer},
modes::Dims{N},
transform::AbstractTransform,
act=identity;
stabilizer=identity,
complex_data::Bool=false,
fno_skip::Symbol=:linear,
channel_mlp_skip::Symbol=:soft_gating,
use_channel_mlp::Bool=false,
channel_mlp_expansion::Real=0.5,
kwargs...,
) where {N}
ch::Pair{<:Integer, <:Integer},
modes::Dims{N},
transform::AbstractTransform,
act = identity;
stabilizer = identity,
complex_data::Bool = false,
fno_skip::Symbol = :linear,
channel_mlp_skip::Symbol = :soft_gating,
use_channel_mlp::Bool = false,
channel_mlp_expansion::Real = 0.5,
kwargs...,
) where {N}
in_chs, out_chs = ch

complex_data && (stabilizer = Base.Fix1(decomposed_activation, stabilizer))
Expand Down Expand Up @@ -205,8 +205,8 @@ julia> SpectralKernel(2 => 5, (16,));
```
"""
function SpectralKernel(
ch::Pair{<:Integer,<:Integer}, modes::Dims, act=identity; shift::Bool=false, kwargs...
)
ch::Pair{<:Integer, <:Integer}, modes::Dims, act = identity; shift::Bool = false, kwargs...
)
return OperatorKernel(
ch, modes, FourierTransform{ComplexF32}(modes, shift), act; kwargs...
)
Expand All @@ -218,26 +218,28 @@ end
Appends a uniform grid embedding to the input data along the penultimate dimension.
"""
@concrete struct GridEmbedding <: AbstractLuxLayer
grid_boundaries <: Vector{<:Tuple{<:Real,<:Real}}
grid_boundaries <: Vector{<:Tuple{<:Real, <:Real}}
end

function Base.show(io::IO, layer::GridEmbedding)
return print(io, "GridEmbedding(", join(layer.grid_boundaries, ", "), ")")
end

function (layer::GridEmbedding)(x::AbstractArray{T,N}, ps, st) where {T,N}
function (layer::GridEmbedding)(x::AbstractArray{T, N}, ps, st) where {T, N}
@assert length(layer.grid_boundaries) == N - 2

grid = meshgrid(map(enumerate(layer.grid_boundaries)) do (i, (min, max))
return range(T(min), T(max); length=size(x, i))
end...)
grid = meshgrid(
map(enumerate(layer.grid_boundaries)) do (i, (min, max))
return range(T(min), T(max); length = size(x, i))
end...
)

grid = repeat(
Lux.Utils.contiguous(reshape(grid, size(grid)..., 1)),
ntuple(Returns(1), N - 1)...,
size(x, N),
)
return cat(grid, x; dims=N - 1), st
return cat(grid, x; dims = N - 1), st
end

"""
Expand All @@ -252,19 +254,19 @@ end

function LuxCore.initialparameters(rng::AbstractRNG, layer::ComplexDecomposedLayer)
return (;
real=LuxCore.initialparameters(rng, layer.layer),
imag=LuxCore.initialparameters(rng, layer.layer),
real = LuxCore.initialparameters(rng, layer.layer),
imag = LuxCore.initialparameters(rng, layer.layer),
)
end

function LuxCore.initialstates(rng::AbstractRNG, layer::ComplexDecomposedLayer)
return (;
real=LuxCore.initialstates(rng, layer.layer),
imag=LuxCore.initialstates(rng, layer.layer),
real = LuxCore.initialstates(rng, layer.layer),
imag = LuxCore.initialstates(rng, layer.layer),
)
end

function (layer::ComplexDecomposedLayer)(x::AbstractArray{T,N}, ps, st) where {T,N}
function (layer::ComplexDecomposedLayer)(x::AbstractArray{T, N}, ps, st) where {T, N}
rx = real.(x)
ix = imag.(x)

Expand All @@ -275,7 +277,7 @@ function (layer::ComplexDecomposedLayer)(x::AbstractArray{T,N}, ps, st) where {T
ifn_ix, st_imag = layer.layer(ix, ps.imag, st_imag)

out = Complex.(rfn_rx .- ifn_ix, rfn_ix .+ ifn_rx)
return out, (; real=st_real, imag=st_imag)
return out, (; real = st_real, imag = st_imag)
end

"""
Expand Down
24 changes: 12 additions & 12 deletions src/models/deeponet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ end
function DeepONet(branch, trunk)
return DeepONet(
Chain(
Parallel(*; branch=Chain(branch, WrappedFunction(adjoint)), trunk=trunk),
Parallel(*; branch = Chain(branch, WrappedFunction(adjoint)), trunk = trunk),
WrappedFunction(adjoint),
),
)
Expand Down Expand Up @@ -94,11 +94,11 @@ julia> size(first(deeponet((u, y), ps, st)))
```
"""
function DeepONet(;
branch=(64, 32, 32, 16),
trunk=(1, 8, 8, 16),
branch_activation=identity,
trunk_activation=identity,
)
branch = (64, 32, 32, 16),
trunk = (1, 8, 8, 16),
branch_activation = identity,
trunk_activation = identity,
)

# checks for last dimension size
@assert branch[end] == trunk[end] "Branch and Trunk net must share the same amount \
Expand All @@ -108,18 +108,18 @@ function DeepONet(;
branch_net = Chain(
[
Dense(
branch[i] => branch[i + 1],
ifelse(i == length(branch) - 1, identity, branch_activation),
) for i in 1:(length(branch) - 1)
branch[i] => branch[i + 1],
ifelse(i == length(branch) - 1, identity, branch_activation),
) for i in 1:(length(branch) - 1)
]...,
)

trunk_net = Chain(
[
Dense(
trunk[i] => trunk[i + 1],
ifelse(i == length(trunk) - 1, identity, trunk_activation),
) for i in 1:(length(trunk) - 1)
trunk[i] => trunk[i + 1],
ifelse(i == length(trunk) - 1, identity, trunk_activation),
) for i in 1:(length(trunk) - 1)
]...,
)

Expand Down
62 changes: 31 additions & 31 deletions src/models/fno.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ julia> size(first(fno(u, ps, st)))
end

function FourierNeuralOperator(
σ=gelu; chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), modes::Dims{M}=(16,), kwargs...
) where {C,M}
σ = gelu; chs::Dims{C} = (2, 64, 64, 64, 64, 64, 128, 1), modes::Dims{M} = (16,), kwargs...
) where {C, M}
@assert length(chs) ≥ 5

return FourierNeuralOperator(
Expand All @@ -54,7 +54,7 @@ function FourierNeuralOperator(
Chain(
[
SpectralKernel(chs[i] => chs[i + 1], modes, σ; kwargs...) for
i in 2:(C - 3)
i in 2:(C - 3)
]...,
),
Chain(
Expand Down Expand Up @@ -111,23 +111,23 @@ Constructor for a Fourier neural operator (FNO) model.
- `shift`: Whether to apply `fftshift` before truncating the modes.
"""
function FourierNeuralOperator(
modes::Dims{N},
in_channels::Integer,
out_channels::Integer,
hidden_channels::Integer;
num_layers::Integer=4,
lifting_channel_ratio::Integer=2,
projection_channel_ratio::Integer=2,
positional_embedding::Union{Symbol,AbstractLuxLayer}=:grid, # :grid | :none
activation=gelu,
use_channel_mlp::Bool=true,
channel_mlp_expansion::Real=0.5,
channel_mlp_skip::Symbol=:soft_gating,
fno_skip::Symbol=:linear,
complex_data::Bool=false,
stabilizer=tanh,
shift::Bool=false,
) where {N}
modes::Dims{N},
in_channels::Integer,
out_channels::Integer,
hidden_channels::Integer;
num_layers::Integer = 4,
lifting_channel_ratio::Integer = 2,
projection_channel_ratio::Integer = 2,
positional_embedding::Union{Symbol, AbstractLuxLayer} = :grid, # :grid | :none
activation = gelu,
use_channel_mlp::Bool = true,
channel_mlp_expansion::Real = 0.5,
channel_mlp_skip::Symbol = :soft_gating,
fno_skip::Symbol = :linear,
complex_data::Bool = false,
stabilizer = tanh,
shift::Bool = false,
) where {N}
lifting_channels = hidden_channels * lifting_channel_ratio
projection_channels = out_channels * projection_channel_ratio

Expand Down Expand Up @@ -155,17 +155,17 @@ function FourierNeuralOperator(
fno_blocks = Chain(
[
SpectralKernel(
hidden_channels => hidden_channels,
modes,
activation;
stabilizer,
shift,
use_channel_mlp,
channel_mlp_expansion,
channel_mlp_skip,
fno_skip,
complex_data,
) for _ in 1:num_layers
hidden_channels => hidden_channels,
modes,
activation;
stabilizer,
shift,
use_channel_mlp,
channel_mlp_expansion,
channel_mlp_skip,
fno_skip,
complex_data,
) for _ in 1:num_layers
]...,
)

Expand Down
Loading
Loading