
Add NorMuon optimizer and Cautious Weight Decay #465

Open
kabachuha wants to merge 2 commits into tdrussell:main from kabachuha:cwd

Conversation

@kabachuha
Contributor

NorMuon and Cautious Weight Decay have recently set two new records on the GPT speedrun repository and have proven effective in their accompanying papers (the latter technique was tested exhaustively, with around 20,000 GPU hours).

Both methods work out-of-the-box, with minimal modifications to the Muon/GenericOptim class.

Cautious Weight Decay is intended to improve the final convergence result (steering parameters toward a more stable region), while NorMuon simply lets training reach a given loss faster.
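To make the masking idea concrete, here is a minimal sketch. The sign-agreement condition and all names are my reading of the CWD paper, not diffusion-pipe's actual code:

```python
import torch

def cautious_weight_decay_(p: torch.Tensor, update: torch.Tensor,
                           lr: float, wd: float) -> None:
    """Decoupled weight decay applied only on coordinates where the
    parameter and the optimizer update agree in sign (one reading of
    the CWD masking rule)."""
    mask = (p * update > 0).to(p.dtype)  # 1.0 where decay is applied
    p.mul_(1.0 - lr * wd * mask)
```

Coordinates that fail the sign check are left entirely undecayed, which is what distinguishes this from plain decoupled weight decay.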

In my experience, NorMuon runs slightly slower than Muon per step, but if it achieves a better loss, that's an acceptable trade-off.

Following ModdedGPT's code, I also added a "schedule" for weight decay: it is kept proportional to the learning rate, so it gradually increases at the beginning of training and gradually decreases at the end (assuming the learning rate schedule is enabled).
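The coupling can be sketched as follows (hypothetical function name; the actual hook into diffusion-pipe's LR schedule may differ):

```python
def scheduled_weight_decay(base_wd: float, lr_now: float, lr_peak: float) -> float:
    # Scale weight decay with the current learning rate: it ramps up
    # during LR warmup and decays alongside the LR at the end of training.
    return base_wd * lr_now / lr_peak
```

At peak learning rate this reduces to the configured weight decay, and it goes to zero together with the LR.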

NorMuon: https://arxiv.org/pdf/2510.05491, https://github.com/zichongli5/NorMuon.

Cautious Weight Decay: https://arxiv.org/pdf/2510.12402, kozistr/pytorch_optimizer@99dc4a5, KellerJordan/modded-nanogpt#154.
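For readers following along, the core NorMuon addition on top of Muon can be sketched roughly like this (illustrative only; the buffer handling, beta, and eps values are assumptions, not the paper's exact algorithm):

```python
import torch

def normuon_normalize(update: torch.Tensor, v: torch.Tensor,
                      beta2: float = 0.95, eps: float = 1e-8) -> torch.Tensor:
    """Post-process an (approximately) orthogonalized Muon update with a
    per-row second-moment normalization, NorMuon-style."""
    # EMA of the mean squared entry per output neuron (one value per row)
    v.mul_(beta2).add_(update.square().mean(dim=-1, keepdim=True), alpha=1 - beta2)
    # Normalize each row by its RMS estimate
    return update / (v.sqrt() + eps)
```

The buffer `v` has shape `(rows, 1)`, so the extra memory cost is negligible compared to Adam-style full second moments.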

Closes #464.

@kabachuha kabachuha mentioned this pull request Nov 30, 2025
seruva19 added a commit to seruva19/takenoko that referenced this pull request Nov 30, 2025
# The projector is a class which contains tensors, so state needs to be explicitly moved to the correct device.
if 'projector' in state:
state['projector'].to(p.device)
# Handle second_momentum_buffer for NorMuon
Owner


You don't need this; all state is already handled above.

vnorm = numerator.norm(dim=(-2, -1), keepdim=True)

# Compute second moment (v_mean) - mean over the last dimension
v_mean = torch.mean(numerator * numerator, dim=-1, keepdim=True)
Owner


Everything to do with second moment can be more cleanly implemented by putting it in get_denominator() with a new second_moment_type='columns' option.
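Something like the following, perhaps (purely illustrative; `get_denominator`'s real signature and behavior in this codebase may differ):

```python
import torch

def get_denominator(buf: torch.Tensor, update: torch.Tensor, beta2: float,
                    eps: float, second_moment_type: str = "full") -> torch.Tensor:
    """Hypothetical sketch of folding NorMuon's per-row second moment
    into a shared helper via a second_moment_type='columns' option."""
    if second_moment_type == "columns":
        # NorMuon-style: average the squared update over columns,
        # keeping one second-moment entry per row.
        sq = update.square().mean(dim=-1, keepdim=True)
    else:
        sq = update.square()
    buf.mul_(beta2).add_(sq, alpha=1 - beta2)
    return buf.sqrt() + eps
```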

# Apply normalization
numerator.mul_(step_size_per_element)

# Compute new norm and rescale to maintain original norm
Owner


What is this? I don't see it in the paper.

Owner


I see GitHub doesn't show the exact line I commented on; I mean specifically the part about rescaling to maintain the original norm.

numerator.mul_(vnorm / (vnorm_new + group["eps"]))

# Apply the same scaling as Muon
step_size *= math.sqrt(max(rows, cols))
Owner


Should probably just directly implement Algorithm 1, line 10 in the paper, instead of this. They are very similar. For an orthonormal matrix, the frobenius norm is sqrt(min(rows, cols)). So if the update were exactly orthogonalized, that line in the paper reduces to this code. But it is approximately orthogonalized, not exact, so there is a difference.

OG Muon does it like this code, but it's better to stick exactly with what the NorMuon paper says.
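The distinction above can be made concrete (this is my reading of the two scalings; the exact constant in Algorithm 1, line 10 should be checked against the paper):

```python
import math

def muon_scale(rows: int, cols: int) -> float:
    # OG Muon: fixed scale, implicitly assuming the update is exactly
    # orthogonalized, i.e. its Frobenius norm equals sqrt(min(rows, cols)).
    return math.sqrt(max(rows, cols))

def paper_scale(rows: int, cols: int, update_fro_norm: float) -> float:
    # Normalize by the measured Frobenius norm instead of its ideal value;
    # this coincides with muon_scale only for an exactly orthogonal update.
    return math.sqrt(max(rows, cols)) * math.sqrt(min(rows, cols)) / update_fro_norm
```

For an exactly orthogonalized update the two agree; for the approximate Newton-Schulz output they differ slightly, which is the owner's point.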

@kabachuha
Contributor Author

Okay, thank you for the review! I'll look into simplifying the NorMuon part.

To be transparent, this was vibe-coded based on Muon's, Normuon's, and diffusion-pipe codebases. 😅

The CWD integration, however, is fully hand-written.

@tdrussell
Owner

You can also just leave this PR open. I will want to test out NorMuon myself at some point, and I can make the changes myself when I do that if it's not done already.



Development

Successfully merging this pull request may close these issues.

Two new optimization techniques: NorMuon and Cautious Weight Decay
