diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index d058286..f416b02 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -37,6 +37,8 @@ jobs:
         version: ${{ matrix.version }}
         arch: ${{ matrix.arch }}
       - uses: julia-actions/cache@v2
+      - name: registry_add
+        run: julia -e 'using Pkg; Pkg.Registry.add(Pkg.Registry.RegistrySpec(url="https://github.com/MurrellGroup/MurrellGroupRegistry"))'
       - uses: julia-actions/julia-buildpkg@v1
       - uses: julia-actions/julia-runtest@v1
       - uses: julia-actions/julia-processcoverage@v1
@@ -62,6 +64,8 @@
         shell: julia --project=docs --color=yes {0}
         run: |
           using Pkg
+          Pkg.Registry.add(Pkg.Registry.RegistrySpec(url="https://github.com/JuliaRegistries/General"))
+          Pkg.Registry.add(Pkg.Registry.RegistrySpec(url="https://github.com/MurrellGroup/MurrellGroupRegistry"))
           Pkg.develop(PackageSpec(path=pwd()))
           Pkg.instantiate()
       - uses: julia-actions/julia-buildpkg@v1
diff --git a/Project.toml b/Project.toml
index aec8d16..1bc566c 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,34 +1,20 @@
 name = "CannotWaitForTheseOptimisers"
 uuid = "16124dda-d9fe-413b-a880-e3f4df3aa341"
-authors = ["murrellb and contributors"]
 version = "0.1.1"
+authors = ["murrellb and contributors"]
+
+[workspace]
+projects = ["test", "docs"]
 
 [deps]
+Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
+MatrixSign = "a25fa8c1-e8fe-40fb-8be2-d139e369b1d5"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 
 [compat]
-ChainRulesCore = "1"
-Functors = "0.5"
-LinearAlgebra = "1"
+Functors = "0.5.2"
+MatrixSign = "0.0.1"
 Optimisers = "0.4"
 Random = "1"
-StaticArrays = "1"
-Statistics = "1"
-Test = "1"
-Zygote = "0.6"
 julia = "1.10"
-
-[extras]
-Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
-Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
-ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
-LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
-Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
-Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
-Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
-
-[targets]
-test = ["Test", "ChainRulesCore", "Functors", "LinearAlgebra", "Optimisers", "StaticArrays", "Statistics", "Zygote"]
diff --git a/docs/src/index.md b/docs/src/index.md
index 17345db..7f028ab 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -12,3 +12,9 @@ Documentation for [CannotWaitForTheseOptimisers](https://github.com/MurrellGroup
 ```@autodocs
 Modules = [CannotWaitForTheseOptimisers]
 ```
+
+## Experimental
+
+```@autodocs
+Modules = [CannotWaitForTheseOptimisers.Experimental]
+```
\ No newline at end of file
diff --git a/src/CannotWaitForTheseOptimisers.jl b/src/CannotWaitForTheseOptimisers.jl
index fa5bf9c..8c01925 100644
--- a/src/CannotWaitForTheseOptimisers.jl
+++ b/src/CannotWaitForTheseOptimisers.jl
@@ -5,6 +5,8 @@ import Optimisers: OptimiserChain, AbstractRule, Leaf, adjust, adjust!, _adjust,
 
 include("rules.jl")
 include("adjust.jl")
+include("setup.jl")
+include("Experimental/Experimental.jl")
 
 export Muon, Apollo, NormGrowthCap, GradNormControl, AdaptiveGradNormControl
 
diff --git a/src/Experimental/Experimental.jl b/src/Experimental/Experimental.jl
new file mode 100644
index 0000000..3f447da
--- /dev/null
+++ b/src/Experimental/Experimental.jl
@@ -0,0 +1,13 @@
+module Experimental
+
+include("Muon.jl")
+export Muon
+
+using Base: IdSet
+export IdSet
+
+using Functors: fcollect
+findnodes(pred::Function, x) = filter(pred, fcollect(x))
+export findnodes
+
+end
diff --git a/src/Experimental/Muon.jl b/src/Experimental/Muon.jl
new file mode 100644
index 0000000..db4edf7
--- /dev/null
+++ b/src/Experimental/Muon.jl
@@ -0,0 +1,58 @@
+using MatrixSign
+using Optimisers: AbstractRule, @lazy, @..
+import Optimisers: init, apply!, adjust!
+
+nonfirstdims(x, dims=ndims(x)) = prod(size(x)[2:dims])
+nonfirstdims(x, ::Nothing) = nonfirstdims(x)
+
+"""
+    Muon(η = 0.02, μ = 0.95, λ = 0.01; dims = nothing)
+    Muon(; [eta, mu, lambda, dims])
+
+Muon - MomentUm Orthogonalized by Newton-schulz (https://github.com/KellerJordan/Muon)
+
+Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-processing step,
+in which each 2D parameter's update is replaced with the nearest orthogonal matrix using Newton-Schulz iteration.
+
+# Parameters
+- Learning rate (`η == eta`): Amount by which gradients are discounted before updating the weights
+- Momentum (`μ == mu`): Controls the acceleration of gradient descent in the prominent direction
+- Weight decay (`λ == lambda`): Controls the strength of ``L_2`` regularisation.
+- Keyword `dims`: Dimensions to orthogonalize. If `nothing`, then trailing dimensions get flattened
+  into the second dimension. If `dims < ndims(x)`, remaining dimensions are orthogonalized independently.
+
+Note: Works best with large batch sizes and may not be suitable for fine-tuning.
+In nanoGPT speedrun experiments, Muon is used for the internal layer >2D weights, and AdamW is used for the 1D weights, embeddings, and heads.
+
+`Optimisers.adjust!(optimiser_state, η::Real)` will adjust the fallback optimizer's `eta` to `η * (opt.eta / eta)`, and Muon's `eta` to `η`, preserving their ratio,
+but `Optimisers.adjust!(optimiser, eta = η)` will only adjust Muon's learning rate (allowing you to adjust the fallback optimizer's learning rate separately).
+"""
+@kwdef struct Muon <: AbstractRule
+    eta = 0.02
+    mu = 0.95
+    lambda = 0.01
+    dims = nothing
+end
+
+init(::Muon, x::AbstractArray) = zero(x)
+
+function apply!(
+    (; eta, mu, lambda, dims)::Muon,
+    state, x::AbstractArray{T}, dx
+) where T
+    η, μ, λ = T(eta), T(mu), T(lambda)
+    # update momentum
+    @.. state = μ * state + (1-μ) * dx
+    # Nesterov update fed to msign
+    U = @. μ * state + (1-μ) * dx
+    # orthogonalize
+    Ot = msign!(
+        reshape(U, size(U, 1), nonfirstdims(U, dims), :),
+        steps=5, fused=3)
+    # post shape factor
+    s = √max(1, T(size(Ot, 1) / size(Ot, 2)))
+    dx′ = @lazy η * (Ot * s + λ * x) # decoupled WD, step will subtract dx′
+    return state, dx′
+end
+
+adjust!(r::Muon, η::Real) = adjust!(r, eta = η)
diff --git a/src/setup.jl b/src/setup.jl
new file mode 100644
index 0000000..af9313a
--- /dev/null
+++ b/src/setup.jl
@@ -0,0 +1,26 @@
+# See https://github.com/FluxML/Optimisers.jl/pull/204
+
+setup(rule::AbstractRule, model) = setup(Returns(rule), model)
+function setup(fun::Function, model)
+    cache = IdDict()
+    tree = _setup(fun, model; cache)
+    isempty(cache) && @warn "setup found no trainable parameters in this model"
+    tree
+end
+
+# _setup is almost fmapstructure, but needs a _trainable_walk, and a cache which ignores numbers etc.
+function _setup(fun::Function, x; cache)
+    haskey(cache, x) && return cache[x]
+    if isnumeric(x)
+        rule = fun(x)::AbstractRule
+        ℓ = Leaf(rule, init(rule, x))
+        if isbits(x)
+            cache[nothing] = nothing # just to disable the warning
+            ℓ
+        else
+            cache[x] = ℓ
+        end
+    else
+        mapvalue(xᵢ -> _setup(fun, xᵢ; cache), _trainable(x))
+    end
+end
\ No newline at end of file
diff --git a/test/Project.toml b/test/Project.toml
new file mode 100644
index 0000000..b1db398
--- /dev/null
+++ b/test/Project.toml
@@ -0,0 +1,11 @@
+[deps]
+CannotWaitForTheseOptimisers = "16124dda-d9fe-413b-a880-e3f4df3aa341"
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
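
For reviewers, a minimal usage sketch of what this patch enables (not part of the diff itself). It assumes the standard Optimisers.jl `setup`/`update!` workflow applies to `Experimental.Muon` like any other `AbstractRule`, that the state tree produced by the per-parameter `setup(fun, model)` form added in `src/setup.jl` (mirroring FluxML/Optimisers.jl#204) composes with `Optimisers.update!`, and that it is reachable as `CannotWaitForTheseOptimisers.setup`; the toy model, array sizes, and the AdamW fallback are illustrative only.

```julia
using Optimisers
using CannotWaitForTheseOptimisers
using CannotWaitForTheseOptimisers: Experimental

# Toy "model": a 2D weight (a candidate for Muon) and a 1D bias (not).
model = (W = randn(Float32, 32, 64), b = zeros(Float32, 32))
grads = (W = randn(Float32, 32, 64), b = randn(Float32, 32))  # stand-in gradients

# Plain Optimisers.jl workflow with the experimental rule on a single array.
rule  = Experimental.Muon(eta = 0.02, mu = 0.95, lambda = 0.01)
state = Optimisers.setup(rule, model.W)
state, W′ = Optimisers.update!(state, model.W, grads.W)

# Per-parameter rule selection via the setup(fun, model) form from src/setup.jl:
# Muon for 2D-and-higher arrays, AdamW for everything else, echoing the docstring note.
tree = CannotWaitForTheseOptimisers.setup(model) do x
    ndims(x) >= 2 ? Experimental.Muon() : Optimisers.AdamW(eta = 1e-3)
end
tree, model′ = Optimisers.update!(tree, model, grads)
```

Running this locally would also require the MurrellGroup registry added in the CI changes above, since `MatrixSign` is resolved from it.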