Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ jobs:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: julia-actions/cache@v2
- name: registry_add
run: julia -e 'using Pkg; Pkg.Registry.add(Pkg.Registry.RegistrySpec(url="https://github.com/MurrellGroup/MurrellGroupRegistry"))'
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
- uses: julia-actions/julia-processcoverage@v1
Expand All @@ -62,6 +64,8 @@ jobs:
shell: julia --project=docs --color=yes {0}
run: |
using Pkg
Pkg.Registry.add(Pkg.Registry.RegistrySpec(url="https://github.com/JuliaRegistries/General"))
Pkg.Registry.add(Pkg.Registry.RegistrySpec(url="https://github.com/MurrellGroup/MurrellGroupRegistry"))
Pkg.develop(PackageSpec(path=pwd()))
Pkg.instantiate()
- uses: julia-actions/julia-buildpkg@v1
Expand Down
30 changes: 8 additions & 22 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,34 +1,20 @@
name = "CannotWaitForTheseOptimisers"
uuid = "16124dda-d9fe-413b-a880-e3f4df3aa341"
authors = ["murrellb <murrellb@gmail.com> and contributors"]
version = "0.1.1"
authors = ["murrellb <murrellb@gmail.com> and contributors"]

[workspace]
projects = ["test", "docs"]

[deps]
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
MatrixSign = "a25fa8c1-e8fe-40fb-8be2-d139e369b1d5"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
ChainRulesCore = "1"
Functors = "0.5"
LinearAlgebra = "1"
Functors = "0.5.2"
MatrixSign = "0.0.1"
Optimisers = "0.4"
Random = "1"
StaticArrays = "1"
Statistics = "1"
Test = "1"
Zygote = "0.6"
julia = "1.10"

[extras]
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "ChainRulesCore", "Functors", "LinearAlgebra", "Optimisers", "StaticArrays", "Statistics", "Zygote"]
6 changes: 6 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,9 @@ Documentation for [CannotWaitForTheseOptimisers](https://github.com/MurrellGroup
```@autodocs
Modules = [CannotWaitForTheseOptimisers]
```

## Experimental

```@autodocs
Modules = [CannotWaitForTheseOptimisers.Experimental]
```
2 changes: 2 additions & 0 deletions src/CannotWaitForTheseOptimisers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import Optimisers: OptimiserChain, AbstractRule, Leaf, adjust, adjust!, _adjust,

include("rules.jl")
include("adjust.jl")
include("setup.jl")
include("Experimental/Experimental.jl")

export Muon, Apollo, NormGrowthCap, GradNormControl, AdaptiveGradNormControl

Expand Down
13 changes: 13 additions & 0 deletions src/Experimental/Experimental.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module Experimental

# Staging area for optimiser code whose API is still in flux.

include("Muon.jl")
export Muon

using Base: IdSet
export IdSet

using Functors: fcollect

"""
    findnodes(pred, x)

Traverse `x` with `Functors.fcollect` and return every node for which
`pred` returns `true`.
"""
function findnodes(pred::Function, x)
    return filter(pred, fcollect(x))
end
export findnodes

end
58 changes: 58 additions & 0 deletions src/Experimental/Muon.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
using MatrixSign
using Optimisers: AbstractRule, @lazy, @..
import Optimisers: init, apply!, adjust!

# Size of the "column" axis when `x` is viewed as a matrix: the product of the
# sizes of dimensions 2 through `dims`. The default (`dims = ndims(x)`, also
# reachable by passing `nothing`) flattens all trailing dimensions together;
# for a vector the product over the empty range is 1.
function nonfirstdims(x, dims=ndims(x))
    sz = size(x)
    return prod(sz[2:dims])
end
nonfirstdims(x, ::Nothing) = nonfirstdims(x)

"""
    Muon(; eta = 0.02, mu = 0.95, lambda = 0.01, dims = nothing)

Muon - MomentUm Orthogonalized by Newton-schulz (https://github.com/KellerJordan/Muon)

Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-processing step,
in which each 2D parameter's update is replaced with the nearest orthogonal matrix using Newton-Schulz iteration.

# Parameters
- Learning rate (`eta`): Amount by which gradients are discounted before updating the weights
- Momentum (`mu`): Controls the acceleration of gradient descent in the prominent direction
- Weight decay (`lambda`): Controls the strength of ``L_2`` regularisation.
- Keyword `dims`: Dimensions to orthogonalize. If `nothing`, then trailing dimensions get flattened
  into the second dimension. If `dims < ndims(x)`, remaining dimensions are orthogonalized independently.

Note: Works best with large batch sizes and may not be suitable for fine-tuning.
In nanoGPT speedrun experiments, Muon is used for the internal layer >2D weights, and AdamW is used for the 1D weights, embeddings, and heads.

`Optimisers.adjust!(optimiser_state, η::Real)` will adjust the fallback optimizer's `eta` to `η * (opt.eta / eta)`, and Muon's `eta` to `η`, preserving their ratio,
but `Optimisers.adjust!(optimiser, eta = η)` will only adjust Muon's learning rate (allowing you to adjust the fallback optimizer's learning rate separately).
"""
@kwdef struct Muon <: AbstractRule
    eta = 0.02     # learning rate applied to the orthogonalized update
    mu = 0.95      # momentum coefficient for the internal SGD-momentum buffer
    lambda = 0.01  # decoupled weight-decay strength
    dims = nothing # dims to orthogonalize; `nothing` flattens all trailing dims
end

init(::Muon, x::AbstractArray) = zero(x)

# One Muon step for array parameter `x` with gradient `dx`:
# 1. update the momentum buffer `state` in place,
# 2. form a Nesterov-style update, orthogonalize it with Newton-Schulz,
# 3. return `(state, dx′)` where `dx′` is the (lazy) quantity the caller
#    subtracts from `x` (includes decoupled weight decay).
function apply!(
    (; eta, mu, lambda, dims)::Muon,
    state, x::AbstractArray{T}, dx
) where T
    # Cast hyperparameters to the parameter eltype to keep arithmetic type-stable.
    η, μ, λ = T(eta), T(mu), T(lambda)
    # update momentum
    # (`@..` is Optimisers' maybe-in-place fused broadcast; mutates `state`.)
    @.. state = μ * state + (1-μ) * dx
    # Nesterov update fed to msign
    U = @. μ * state + (1-μ) * dx
    # orthogonalize
    # NOTE(review): assumes `msign!` (MatrixSign.jl) orthogonalizes each
    # size(U,1) × nonfirstdims(U,dims) slice and returns the reshaped array —
    # confirm its in-place/return contract.
    Ot = msign!(
        reshape(U, size(U, 1), nonfirstdims(U, dims), :),
        steps=5, fused=3)
    # post shape factor
    s = √max(1, T(size(Ot, 1) / size(Ot, 2)))
    # NOTE(review): `Ot` is 3-D after the reshape while `x` keeps its original
    # shape; the broadcast below relies on trailing axes lining up (e.g. a
    # trailing singleton for 2-D `x`) — confirm for the `dims < ndims(x)` case.
    dx′ = @lazy η * (Ot * s + λ * x) # decoupled WD, step will subtract dx′
    return state, dx′
end

adjust!(r::Muon, η::Real) = adjust!(r, eta = η)
26 changes: 26 additions & 0 deletions src/setup.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# See https://github.com/FluxML/Optimisers.jl/pull/204

"""
    setup(rule::AbstractRule, model)
    setup(fun, model)

Build an optimiser state tree for `model`: either apply a single `rule` to
every trainable array, or call `fun(x)` to pick the rule for each array `x`.
Warns when the traversal finds no trainable parameters.
"""
setup(rule::AbstractRule, model) = setup(Returns(rule), model)
function setup(fun::Function, model)
    seen = IdDict()
    tree = _setup(fun, model; cache = seen)
    if isempty(seen)
        @warn "setup found no trainable parameters in this model"
    end
    return tree
end

# _setup is almost fmapstructure, but needs a _trainable_walk, and a cache which ignores numbers etc.
# _setup is almost fmapstructure, but needs a _trainable_walk, and a cache which ignores numbers etc.
# Recursively walks `x`; numeric leaves become `Leaf(rule, state)`, everything
# else is rebuilt via `_trainable`/`mapvalue`. The IdDict `cache` both dedupes
# shared (tied) non-isbits arrays and signals to `setup` that parameters exist.
function _setup(fun::Function, x; cache)
    haskey(cache, x) && return cache[x]
    if isnumeric(x)
        rule = fun(x)::AbstractRule
        ℓ = Leaf(rule, init(rule, x))
        if isbits(x)
            cache[nothing] = nothing  # just to disable the warning
            # BUGFIX: the Leaf must still be returned for isbits parameters
            # (they can't be cached by identity); previously this branch
            # returned `nothing`, dropping the optimiser state.
            ℓ
        else
            cache[x] = ℓ  # identity-cache so tied arrays share one Leaf
        end
    else
        mapvalue(xᵢ -> _setup(fun, xᵢ; cache), _trainable(x))
    end
end
11 changes: 11 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[deps]
CannotWaitForTheseOptimisers = "16124dda-d9fe-413b-a880-e3f4df3aa341"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"