diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml deleted file mode 100644 index 28be623..0000000 --- a/.JuliaFormatter.toml +++ /dev/null @@ -1,3 +0,0 @@ -style = "blue" -pipe_to_function_call = false -always_use_return = true diff --git a/.github/workflows/FormatPR.yml b/.github/workflows/FormatPR.yml index 81a57a3..6762c6f 100644 --- a/.github/workflows/FormatPR.yml +++ b/.github/workflows/FormatPR.yml @@ -1,29 +1,19 @@ -name: FormatPR +name: format-check + on: - schedule: - - cron: '0 0 * * *' + push: + branches: + - 'master' + - 'main' + - 'release-' + tags: '*' + pull_request: + jobs: - build: + runic: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v6 - - name: Install JuliaFormatter and format - run: | - julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter", version="1"))' - julia -e 'using JuliaFormatter; format(".")' - # https://github.com/marketplace/actions/create-pull-request - # https://github.com/peter-evans/create-pull-request#reference-example - - name: Create Pull Request - id: cpr - uses: peter-evans/create-pull-request@v7 + - uses: actions/checkout@v4 + - uses: fredrikekre/runic-action@v1 with: - token: ${{ secrets.GITHUB_TOKEN }} - commit-message: Format .jl files - title: 'Automatic JuliaFormatter.jl run' - branch: auto-juliaformatter-pr - delete-branch: true - labels: formatting, automated pr, no changelog - - name: Check outputs - run: | - echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}" - echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}" \ No newline at end of file + version: '1' diff --git a/docs/make.jl b/docs/make.jl index e9048a7..dd7f093 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,7 +1,7 @@ using Documenter, NeuralOperators -cp("./docs/Manifest.toml", "./docs/src/assets/Manifest.toml"; force=true) -cp("./docs/Project.toml", "./docs/src/assets/Project.toml"; force=true) +cp("./docs/Manifest.toml", "./docs/src/assets/Manifest.toml"; force = true) +cp("./docs/Project.toml", "./docs/src/assets/Project.toml"; force = true) ENV["GKSwstype"] = "100" ENV["DATADEPS_ALWAYS_ACCEPT"] = true @@ -9,17 +9,17 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = true include("pages.jl") makedocs(; - sitename="NeuralOperators.jl", - clean=true, - doctest=false, - linkcheck=true, - modules=[NeuralOperators], - format=Documenter.HTML(; - prettyurls=get(ENV, "CI", "false") == "true", - canonical="https://docs.sciml.ai/NeuralOperators/stable/", - assets=["assets/favicon.ico"], + sitename = "NeuralOperators.jl", + clean = true, + doctest = false, + linkcheck = true, + modules = [NeuralOperators], + format = Documenter.HTML(; + prettyurls = get(ENV, "CI", "false") == "true", + canonical = "https://docs.sciml.ai/NeuralOperators/stable/", + assets = ["assets/favicon.ico"], ), pages, ) -deploydocs(; repo="github.com/SciML/NeuralOperators.jl.git", push_preview=true) +deploydocs(; repo = "github.com/SciML/NeuralOperators.jl.git", push_preview = true) diff --git a/src/layers.jl b/src/layers.jl index 2fffaaa..e5d47dc 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -41,9 +41,9 @@ function LuxCore.initialparameters(rng::AbstractRNG, layer::OperatorConv) in_chs, out_chs = layer.in_chs, layer.out_chs scale = real(one(eltype(layer.tform))) / (in_chs * out_chs) return (; - weight=scale * layer.init_weight( + weight = scale * layer.init_weight( rng, eltype(layer.tform), out_chs, in_chs, layer.prod_modes - ) + ), ) end @@ -52,22 +52,22 @@ function LuxCore.parameterlength(layer::OperatorConv) end function OperatorConv( - ch::Pair{<:Integer,<:Integer}, - modes::Dims, - tform::AbstractTransform; - init_weight=glorot_uniform, -) + ch::Pair{<:Integer, <:Integer}, + modes::Dims, + tform::AbstractTransform; + init_weight = glorot_uniform, + ) return OperatorConv(ch..., prod(modes), tform, init_weight) end -function (conv::OperatorConv)(x::AbstractArray{T,N}, ps, st) where {T,N} +function (conv::OperatorConv)(x::AbstractArray{T, N}, ps, st) where {T, N} x_t = transform(conv.tform, x) x_tr = truncate_modes(conv.tform, x_t) x_p = apply_pattern(x_tr, ps.weight) pad_dims = size(x_t)[1:(end - 2)] .- size(x_p)[1:(end - 2)] x_padded = pad_constant( - x_p, expand_pad_dims(pad_dims), false; dims=ntuple(identity, ndims(x_p) - 2) + x_p, expand_pad_dims(pad_dims), false; dims = ntuple(identity, ndims(x_p) - 2) ) out = inverse(conv.tform, x_padded, x) @@ -88,8 +88,8 @@ julia> SpectralConv(2 => 5, (16,)); ``` """ function SpectralConv( - ch::Pair{<:Integer,<:Integer}, modes::Dims; shift::Bool=false, kwargs... -) + ch::Pair{<:Integer, <:Integer}, modes::Dims; shift::Bool = false, kwargs... + ) return OperatorConv(ch, modes, FourierTransform{ComplexF32}(modes, shift); kwargs...) end @@ -121,18 +121,18 @@ julia> OperatorKernel(2 => 5, (16,), FourierTransform{ComplexF64}((16,))); end function OperatorKernel( - ch::Pair{<:Integer,<:Integer}, - modes::Dims{N}, - transform::AbstractTransform, - act=identity; - stabilizer=identity, - complex_data::Bool=false, - fno_skip::Symbol=:linear, - channel_mlp_skip::Symbol=:soft_gating, - use_channel_mlp::Bool=false, - channel_mlp_expansion::Real=0.5, - kwargs..., -) where {N} + ch::Pair{<:Integer, <:Integer}, + modes::Dims{N}, + transform::AbstractTransform, + act = identity; + stabilizer = identity, + complex_data::Bool = false, + fno_skip::Symbol = :linear, + channel_mlp_skip::Symbol = :soft_gating, + use_channel_mlp::Bool = false, + channel_mlp_expansion::Real = 0.5, + kwargs..., + ) where {N} in_chs, out_chs = ch complex_data && (stabilizer = Base.Fix1(decomposed_activation, stabilizer)) @@ -205,8 +205,8 @@ julia> SpectralKernel(2 => 5, (16,)); ``` """ function SpectralKernel( - ch::Pair{<:Integer,<:Integer}, modes::Dims, act=identity; shift::Bool=false, kwargs... -) + ch::Pair{<:Integer, <:Integer}, modes::Dims, act = identity; shift::Bool = false, kwargs... + ) return OperatorKernel( ch, modes, FourierTransform{ComplexF32}(modes, shift), act; kwargs... ) @@ -218,26 +218,28 @@ end Appends a uniform grid embedding to the input data along the penultimate dimension. """ @concrete struct GridEmbedding <: AbstractLuxLayer - grid_boundaries <: Vector{<:Tuple{<:Real,<:Real}} + grid_boundaries <: Vector{<:Tuple{<:Real, <:Real}} end function Base.show(io::IO, layer::GridEmbedding) return print(io, "GridEmbedding(", join(layer.grid_boundaries, ", "), ")") end -function (layer::GridEmbedding)(x::AbstractArray{T,N}, ps, st) where {T,N} +function (layer::GridEmbedding)(x::AbstractArray{T, N}, ps, st) where {T, N} @assert length(layer.grid_boundaries) == N - 2 - grid = meshgrid(map(enumerate(layer.grid_boundaries)) do (i, (min, max)) - return range(T(min), T(max); length=size(x, i)) - end...) + grid = meshgrid( + map(enumerate(layer.grid_boundaries)) do (i, (min, max)) + return range(T(min), T(max); length = size(x, i)) + end... + ) grid = repeat( Lux.Utils.contiguous(reshape(grid, size(grid)..., 1)), ntuple(Returns(1), N - 1)..., size(x, N), ) - return cat(grid, x; dims=N - 1), st + return cat(grid, x; dims = N - 1), st end """ @@ -252,19 +254,19 @@ end function LuxCore.initialparameters(rng::AbstractRNG, layer::ComplexDecomposedLayer) return (; - real=LuxCore.initialparameters(rng, layer.layer), - imag=LuxCore.initialparameters(rng, layer.layer), + real = LuxCore.initialparameters(rng, layer.layer), + imag = LuxCore.initialparameters(rng, layer.layer), ) end function LuxCore.initialstates(rng::AbstractRNG, layer::ComplexDecomposedLayer) return (; - real=LuxCore.initialstates(rng, layer.layer), - imag=LuxCore.initialstates(rng, layer.layer), + real = LuxCore.initialstates(rng, layer.layer), + imag = LuxCore.initialstates(rng, layer.layer), ) end -function (layer::ComplexDecomposedLayer)(x::AbstractArray{T,N}, ps, st) where {T,N} +function (layer::ComplexDecomposedLayer)(x::AbstractArray{T, N}, ps, st) where {T, N} rx = real.(x) ix = imag.(x) @@ -275,7 +277,7 @@ function (layer::ComplexDecomposedLayer)(x::AbstractArray{T,N}, ps, st) where {T ifn_ix, st_imag = layer.layer(ix, ps.imag, st_imag) out = Complex.(rfn_rx .- ifn_ix, rfn_ix .+ ifn_rx) - return out, (; real=st_real, imag=st_imag) + return out, (; real = st_real, imag = st_imag) end """ diff --git a/src/models/deeponet.jl b/src/models/deeponet.jl index 3a855f3..4ccca34 100644 --- a/src/models/deeponet.jl +++ b/src/models/deeponet.jl @@ -50,7 +50,7 @@ end function DeepONet(branch, trunk) return DeepONet( Chain( - Parallel(*; branch=Chain(branch, WrappedFunction(adjoint)), trunk=trunk), + Parallel(*; branch = Chain(branch, WrappedFunction(adjoint)), trunk = trunk), WrappedFunction(adjoint), ), ) @@ -94,11 +94,11 @@ julia> size(first(deeponet((u, y), ps, st))) ``` """ function DeepONet(; - branch=(64, 32, 32, 16), - trunk=(1, 8, 8, 16), - branch_activation=identity, - trunk_activation=identity, -) + branch = (64, 32, 32, 16), + trunk = (1, 8, 8, 16), + branch_activation = identity, + trunk_activation = identity, + ) # checks for last dimension size @assert branch[end] == trunk[end] "Branch and Trunk net must share the same amount \ @@ -108,18 +108,18 @@ function DeepONet(; branch_net = Chain( [ Dense( - branch[i] => branch[i + 1], - ifelse(i == length(branch) - 1, identity, branch_activation), - ) for i in 1:(length(branch) - 1) + branch[i] => branch[i + 1], + ifelse(i == length(branch) - 1, identity, branch_activation), + ) for i in 1:(length(branch) - 1) ]..., ) trunk_net = Chain( [ Dense( - trunk[i] => trunk[i + 1], - ifelse(i == length(trunk) - 1, identity, trunk_activation), - ) for i in 1:(length(trunk) - 1) + trunk[i] => trunk[i + 1], + ifelse(i == length(trunk) - 1, identity, trunk_activation), + ) for i in 1:(length(trunk) - 1) ]..., ) diff --git a/src/models/fno.jl b/src/models/fno.jl index 1edd877..bf2a77d 100644 --- a/src/models/fno.jl +++ b/src/models/fno.jl @@ -44,8 +44,8 @@ julia> size(first(fno(u, ps, st))) end function FourierNeuralOperator( - σ=gelu; chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), modes::Dims{M}=(16,), kwargs... -) where {C,M} + σ = gelu; chs::Dims{C} = (2, 64, 64, 64, 64, 64, 128, 1), modes::Dims{M} = (16,), kwargs... + ) where {C, M} @assert length(chs) ≥ 5 return FourierNeuralOperator( @@ -54,7 +54,7 @@ function FourierNeuralOperator( Chain( [ SpectralKernel(chs[i] => chs[i + 1], modes, σ; kwargs...) for - i in 2:(C - 3) + i in 2:(C - 3) ]..., ), Chain( @@ -111,23 +111,23 @@ Constructor for a Fourier neural operator (FNO) model. - `shift`: Whether to apply `fftshift` before truncating the modes. """ function FourierNeuralOperator( - modes::Dims{N}, - in_channels::Integer, - out_channels::Integer, - hidden_channels::Integer; - num_layers::Integer=4, - lifting_channel_ratio::Integer=2, - projection_channel_ratio::Integer=2, - positional_embedding::Union{Symbol,AbstractLuxLayer}=:grid, # :grid | :none - activation=gelu, - use_channel_mlp::Bool=true, - channel_mlp_expansion::Real=0.5, - channel_mlp_skip::Symbol=:soft_gating, - fno_skip::Symbol=:linear, - complex_data::Bool=false, - stabilizer=tanh, - shift::Bool=false, -) where {N} + modes::Dims{N}, + in_channels::Integer, + out_channels::Integer, + hidden_channels::Integer; + num_layers::Integer = 4, + lifting_channel_ratio::Integer = 2, + projection_channel_ratio::Integer = 2, + positional_embedding::Union{Symbol, AbstractLuxLayer} = :grid, # :grid | :none + activation = gelu, + use_channel_mlp::Bool = true, + channel_mlp_expansion::Real = 0.5, + channel_mlp_skip::Symbol = :soft_gating, + fno_skip::Symbol = :linear, + complex_data::Bool = false, + stabilizer = tanh, + shift::Bool = false, + ) where {N} lifting_channels = hidden_channels * lifting_channel_ratio projection_channels = out_channels * projection_channel_ratio @@ -155,17 +155,17 @@ function FourierNeuralOperator( fno_blocks = Chain( [ SpectralKernel( - hidden_channels => hidden_channels, - modes, - activation; - stabilizer, - shift, - use_channel_mlp, - channel_mlp_expansion, - channel_mlp_skip, - fno_skip, - complex_data, - ) for _ in 1:num_layers + hidden_channels => hidden_channels, + modes, + activation; + stabilizer, + shift, + use_channel_mlp, + channel_mlp_expansion, + channel_mlp_skip, + fno_skip, + complex_data, + ) for _ in 1:num_layers ]..., ) diff --git a/src/models/nomad.jl b/src/models/nomad.jl index 5b671ab..958e12f 100644 --- a/src/models/nomad.jl +++ b/src/models/nomad.jl @@ -39,7 +39,7 @@ julia> size(first(nomad((u, y), ps, st))) end function NOMAD(approximator, decoder) - return NOMAD(Chain(; approximator=Parallel(vcat, approximator, NoOpLayer()), decoder)) + return NOMAD(Chain(; approximator = Parallel(vcat, approximator, NoOpLayer()), decoder)) end """ @@ -81,22 +81,22 @@ julia> size(first(nomad((u, y), ps, st))) ``` """ function NOMAD(; - approximator=(8, 32, 32, 16), - decoder=(18, 16, 8, 8), - approximator_activation=identity, - decoder_activation=identity, -) + approximator = (8, 32, 32, 16), + decoder = (18, 16, 8, 8), + approximator_activation = identity, + decoder_activation = identity, + ) approximator_net = Chain( [ Dense(approximator[i] => approximator[i + 1], approximator_activation) for - i in 1:(length(approximator) - 1) + i in 1:(length(approximator) - 1) ]..., ) decoder_net = Chain( [ Dense(decoder[i] => decoder[i + 1], decoder_activation) for - i in 1:(length(decoder) - 1) + i in 1:(length(decoder) - 1) ]..., ) diff --git a/src/transform.jl b/src/transform.jl index e4d7e3b..695e8d4 100644 --- a/src/transform.jl +++ b/src/transform.jl @@ -25,13 +25,13 @@ A concrete implementation of `AbstractTransform` for Fourier transforms. If `shift` is `true`, we apply a `fftshift` before truncating the modes. """ -struct FourierTransform{T,M} <: AbstractTransform{T} +struct FourierTransform{T, M} <: AbstractTransform{T} modes::M shift::Bool end -function FourierTransform{T}(modes::Dims, shift::Bool=false) where {T} - return FourierTransform{T,typeof(modes)}(modes, shift) +function FourierTransform{T}(modes::Dims, shift::Bool = false) where {T} + return FourierTransform{T, typeof(modes)}(modes, shift) end function Base.show(io::IO, ft::FourierTransform) @@ -58,8 +58,8 @@ end truncate_modes(ft::FourierTransform, x_fft::AbstractArray) = low_pass(ft, x_fft) function inverse( - ft::FourierTransform, x_fft::AbstractArray{T,N}, x::AbstractArray{T2,N} -) where {T,T2,N} + ft::FourierTransform, x_fft::AbstractArray{T, N}, x::AbstractArray{T2, N} + ) where {T, T2, N} complex_data = Lux.Utils.eltype(x) <: Complex if ft.shift && ndims(ft) > 1 diff --git a/src/utils.jl b/src/utils.jl index 10e2347..ca3bea0 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,6 +1,6 @@ function apply_pattern( - x_tr::AbstractArray{T1,N}, weights::AbstractArray{T2,3} -) where {T1,T2,N} + x_tr::AbstractArray{T1, N}, weights::AbstractArray{T2, 3} + ) where {T1, T2, N} x_size = size(x_tr) x_flat = reshape(x_tr, :, x_size[N - 1], x_size[N]) diff --git a/test/deeponet_tests.jl b/test/deeponet_tests.jl index ae9608d..3f32e1d 100644 --- a/test/deeponet_tests.jl +++ b/test/deeponet_tests.jl @@ -3,20 +3,20 @@ setups = [ ( - u_size=(64, 5), - y_size=(1, 10), - out_size=(10, 5), - branch=(64, 32, 32, 16), - trunk=(1, 8, 8, 16), - name="Scalar", + u_size = (64, 5), + y_size = (1, 10), + out_size = (10, 5), + branch = (64, 32, 32, 16), + trunk = (1, 8, 8, 16), + name = "Scalar", ), ( - u_size=(64, 5), - y_size=(4, 10), - out_size=(10, 5), - branch=(64, 32, 32, 16), - trunk=(4, 8, 8, 16), - name="Vector", + u_size = (64, 5), + y_size = (4, 10), + out_size = (10, 5), + branch = (64, 32, 32, 16), + trunk = (4, 8, 8, 16), + name = "Vector", ), ] @@ -25,7 +25,7 @@ @testset "$(setup.name)" for setup in setups u = rand(Float32, setup.u_size...) y = rand(Float32, setup.y_size...) - deeponet = DeepONet(; branch=setup.branch, trunk=setup.trunk) + deeponet = DeepONet(; branch = setup.branch, trunk = setup.trunk) ps, st = Lux.setup(rng, deeponet) @@ -39,8 +39,8 @@ ∂u_zyg, ∂ps_zyg = zygote_gradient(deeponet, (u, y), ps, st) ∂u_ra, ∂ps_ra = Reactant.with_config(; - dot_general_precision=PrecisionConfig.HIGH, - convolution_precision=PrecisionConfig.HIGH, + dot_general_precision = PrecisionConfig.HIGH, + convolution_precision = PrecisionConfig.HIGH, ) do @jit enzyme_gradient(deeponet, (u_ra, y_ra), ps_ra, st_ra) end @@ -48,7 +48,7 @@ @test ∂u_zyg[1] ≈ ∂u_ra[1] atol = 1.0f-2 rtol = 1.0f-2 @test ∂u_zyg[2] ≈ ∂u_ra[2] atol = 1.0f-2 rtol = 1.0f-2 - @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-2, rtol=1.0f-2) + @test check_approx(∂ps_zyg, ∂ps_ra; atol = 1.0f-2, rtol = 1.0f-2) end end end diff --git a/test/fno_tests.jl b/test/fno_tests.jl index 1fd0227..41cded5 100644 --- a/test/fno_tests.jl +++ b/test/fno_tests.jl @@ -3,25 +3,25 @@ setups = [ ( - modes=(16,), - chs=(2, 64, 64, 64, 64, 64, 128, 1), - x_size=(1024, 2, 5), - y_size=(1024, 1, 5), - shift=false, + modes = (16,), + chs = (2, 64, 64, 64, 64, 64, 128, 1), + x_size = (1024, 2, 5), + y_size = (1024, 1, 5), + shift = false, ), ( - modes=(16, 16), - chs=(2, 64, 64, 64, 64, 64, 128, 4), - x_size=(32, 32, 2, 5), - y_size=(32, 32, 4, 5), - shift=false, + modes = (16, 16), + chs = (2, 64, 64, 64, 64, 64, 128, 4), + x_size = (32, 32, 2, 5), + y_size = (32, 32, 4, 5), + shift = false, ), ( - modes=(16, 16), - chs=(2, 64, 64, 64, 64, 64, 128, 4), - x_size=(32, 32, 2, 5), - y_size=(32, 32, 4, 5), - shift=true, + modes = (16, 16), + chs = (2, 64, 64, 64, 64, 64, 128, 4), + x_size = (32, 32, 2, 5), + y_size = (32, 32, 4, 5), + shift = true, ), ] @@ -40,8 +40,8 @@ res = first(fno(x, ps, st)) res_ra, _ = Reactant.with_config(; - dot_general_precision=PrecisionConfig.HIGH, - convolution_precision=PrecisionConfig.HIGH, + dot_general_precision = PrecisionConfig.HIGH, + convolution_precision = PrecisionConfig.HIGH, ) do @jit fno(x_ra, ps_ra, st_ra) end @@ -49,7 +49,7 @@ @test begin l2, l1 = train!( - MSELoss(), AutoEnzyme(), fno, ps_ra, st_ra, [(x_ra, y_ra)]; epochs=10 + MSELoss(), AutoEnzyme(), fno, ps_ra, st_ra, [(x_ra, y_ra)]; epochs = 10 ) l2 < l1 end @@ -58,8 +58,8 @@ ∂x_zyg, ∂ps_zyg = zygote_gradient(fno, x, ps, st) ∂x_ra, ∂ps_ra = Reactant.with_config(; - dot_general_precision=PrecisionConfig.HIGH, - convolution_precision=PrecisionConfig.HIGH, + dot_general_precision = PrecisionConfig.HIGH, + convolution_precision = PrecisionConfig.HIGH, ) do @jit enzyme_gradient(fno, x_ra, ps_ra, st_ra) end @@ -67,7 +67,7 @@ # TODO: is zygote off here? @test ∂x_zyg ≈ ∂x_ra atol = 1.0f-2 rtol = 1.0f-2 skip = setup.shift - @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-2, rtol=1.0f-2) skip = setup.shift + @test check_approx(∂ps_zyg, ∂ps_ra; atol = 1.0f-2, rtol = 1.0f-2) skip = setup.shift end end end diff --git a/test/layers_tests.jl b/test/layers_tests.jl index 4444ca3..ae2e864 100644 --- a/test/layers_tests.jl +++ b/test/layers_tests.jl @@ -3,15 +3,15 @@ opconv = [SpectralConv, SpectralKernel] setups = [ - (; m=(16,), x_size=(1024, 2, 5), y_size=(1024, 16, 5), shift=false), - (; m=(10, 10), x_size=(22, 22, 1, 5), y_size=(22, 22, 16, 5), shift=false), - (; m=(10, 10), x_size=(22, 22, 1, 5), y_size=(22, 22, 16, 5), shift=true), + (; m = (16,), x_size = (1024, 2, 5), y_size = (1024, 16, 5), shift = false), + (; m = (10, 10), x_size = (22, 22, 1, 5), y_size = (22, 22, 16, 5), shift = false), + (; m = (10, 10), x_size = (22, 22, 1, 5), y_size = (22, 22, 16, 5), shift = true), ] rdev = reactant_device() @testset "$(op) $(length(setup.m))D | shift=$(setup.shift)" for op in opconv, - setup in setups + setup in setups in_chs = setup.x_size[end - 1] out_chs = setup.y_size[end - 1] @@ -31,8 +31,8 @@ y_ra = rdev(rand(rng, Float32, setup.y_size...)) res_ra, _ = Reactant.with_config(; - dot_general_precision=PrecisionConfig.HIGH, - convolution_precision=PrecisionConfig.HIGH, + dot_general_precision = PrecisionConfig.HIGH, + convolution_precision = PrecisionConfig.HIGH, ) do @jit m(x_ra, ps_ra, st_ra) end @@ -42,7 +42,7 @@ @test begin l2, l1 = train!( - MSELoss(), AutoEnzyme(), m, ps_ra, st_ra, [(x_ra, y_ra)]; epochs=10 + MSELoss(), AutoEnzyme(), m, ps_ra, st_ra, [(x_ra, y_ra)]; epochs = 10 ) l2 < l1 end @@ -51,8 +51,8 @@ ∂x_zyg, ∂ps_zyg = zygote_gradient(m, x, ps, st) ∂x_ra, ∂ps_ra = Reactant.with_config(; - dot_general_precision=PrecisionConfig.HIGH, - convolution_precision=PrecisionConfig.HIGH, + dot_general_precision = PrecisionConfig.HIGH, + convolution_precision = PrecisionConfig.HIGH, ) do @jit enzyme_gradient(m, x_ra, ps_ra, st_ra) end @@ -60,7 +60,7 @@ # TODO: is zygote off here? @test ∂x_zyg ≈ ∂x_ra atol = 1.0f-2 rtol = 1.0f-2 skip = setup.shift - @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-2, rtol=1.0f-2) skip = setup.shift + @test check_approx(∂ps_zyg, ∂ps_ra; atol = 1.0f-2, rtol = 1.0f-2) skip = setup.shift end end end diff --git a/test/nomad_tests.jl b/test/nomad_tests.jl index d55f155..c340dc4 100644 --- a/test/nomad_tests.jl +++ b/test/nomad_tests.jl @@ -3,20 +3,20 @@ setups = [ ( - u_size=(1, 5), - y_size=(1, 5), - out_size=(1, 5), - approximator=(1, 16, 16, 15), - decoder=(16, 8, 4, 1), - name="Scalar", + u_size = (1, 5), + y_size = (1, 5), + out_size = (1, 5), + approximator = (1, 16, 16, 15), + decoder = (16, 8, 4, 1), + name = "Scalar", ), ( - u_size=(8, 5), - y_size=(2, 5), - out_size=(8, 5), - approximator=(8, 32, 32, 16), - decoder=(18, 16, 8, 8), - name="Vector", + u_size = (8, 5), + y_size = (2, 5), + out_size = (8, 5), + approximator = (8, 32, 32, 16), + decoder = (18, 16, 8, 8), + name = "Vector", ), ] @@ -25,7 +25,7 @@ @testset "$(setup.name)" for setup in setups u = rand(Float32, setup.u_size...) y = rand(Float32, setup.y_size...) - nomad = NOMAD(; approximator=setup.approximator, decoder=setup.decoder) + nomad = NOMAD(; approximator = setup.approximator, decoder = setup.decoder) ps, st = Lux.setup(rng, nomad) @@ -39,8 +39,8 @@ ∂u_zyg, ∂ps_zyg = zygote_gradient(nomad, (u, y), ps, st) ∂u_ra, ∂ps_ra = Reactant.with_config(; - dot_general_precision=PrecisionConfig.HIGH, - convolution_precision=PrecisionConfig.HIGH, + dot_general_precision = PrecisionConfig.HIGH, + convolution_precision = PrecisionConfig.HIGH, ) do @jit enzyme_gradient(nomad, (u_ra, y_ra), ps_ra, st_ra) end @@ -48,7 +48,7 @@ @test ∂u_zyg[1] ≈ ∂u_ra[1] atol = 1.0f-2 rtol = 1.0f-2 @test ∂u_zyg[2] ≈ ∂u_ra[2] atol = 1.0f-2 rtol = 1.0f-2 - @test check_approx(∂ps_zyg, ∂ps_ra; atol=1.0f-2, rtol=1.0f-2) + @test check_approx(∂ps_zyg, ∂ps_ra; atol = 1.0f-2, rtol = 1.0f-2) end end end diff --git a/test/qa_tests.jl b/test/qa_tests.jl index 57c8042..24b4aea 100644 --- a/test/qa_tests.jl +++ b/test/qa_tests.jl @@ -5,23 +5,23 @@ NeuralOperators, :DocTestSetup, :(using Lux, NeuralOperators, Random); - recursive=true, + recursive = true, ) - doctest(NeuralOperators; manual=false) + doctest(NeuralOperators; manual = false) end @testitem "Aqua: Quality Assurance" tags = [:qa] begin using Aqua - Aqua.test_all(NeuralOperators; ambiguities=false) - Aqua.test_ambiguities(NeuralOperators; recursive=false) + Aqua.test_all(NeuralOperators; ambiguities = false) + Aqua.test_ambiguities(NeuralOperators; recursive = false) end @testitem "Explicit Imports: Quality Assurance" tags = [:qa] begin using ExplicitImports, Lux # Skip our own packages - @test check_no_implicit_imports(NeuralOperators; skip=(Base, Core, Lux)) === nothing + @test check_no_implicit_imports(NeuralOperators; skip = (Base, Core, Lux)) === nothing @test check_no_stale_explicit_imports(NeuralOperators) === nothing @test check_no_self_qualified_accesses(NeuralOperators) === nothing @test check_all_explicit_imports_via_owners(NeuralOperators) === nothing diff --git a/test/runtests.jl b/test/runtests.jl index c777622..7b8141d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,8 +9,8 @@ const RETESTITEMS_NWORKER_THREADS = parse( @testset "NeuralOperators.jl Tests" begin ReTestItems.runtests( NeuralOperators; - nworkers=1, - nworker_threads=RETESTITEMS_NWORKER_THREADS, - testitem_timeout=3600, + nworkers = 1, + nworker_threads = RETESTITEMS_NWORKER_THREADS, + testitem_timeout = 3600, ) end diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl index fd16b02..afe7d2c 100644 --- a/test/shared_testsetup.jl +++ b/test/shared_testsetup.jl @@ -9,7 +9,7 @@ const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All")) train!(args...; kwargs...) = train!(MSELoss(), AutoZygote(), args...; kwargs...) -function train!(loss, backend, model, ps, st, data; epochs=10) +function train!(loss, backend, model, ps, st, data; epochs = 10) l1 = @jit loss(model, ps, st, first(data)) tstate = Training.TrainState(model, ps, st, Adam(0.01f0))