Add NorMuon optimizer and Cautious Weight Decay #465
kabachuha wants to merge 2 commits into tdrussell:main
Conversation
    # The projector is a class which contains tensors, so state needs to be explicitly moved to the correct device.
    if 'projector' in state:
        state['projector'].to(p.device)
    # Handle second_momentum_buffer for NorMuon
You don't need this, all state is already handled above.
    vnorm = numerator.norm(dim=(-2, -1), keepdim=True)

    # Compute second moment (v_mean) - mean over the last dimension
    v_mean = torch.mean(numerator * numerator, dim=-1, keepdim=True)
Everything to do with second moment can be more cleanly implemented by putting it in get_denominator() with a new second_moment_type='columns' option.
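A minimal sketch of how the suggested `second_moment_type='columns'` option could look (the function body here is mine, not the repo's actual `get_denominator()`, and the signature is an assumption):

```python
import torch

def get_denominator(exp_avg_sq, eps, second_moment_type="full"):
    """Hypothetical sketch of the refactor suggested above: one helper that
    returns the denominator for the normalization step.

    second_moment_type="columns" would cover the NorMuon case, where the
    second moment is averaged over the last dimension, giving one value
    per row of the 2-D update."""
    if second_moment_type == "columns":
        # NorMuon-style: mean of squared entries over the last dim, per row
        return exp_avg_sq.mean(dim=-1, keepdim=True).sqrt().add_(eps)
    # default: elementwise (Adam-style) denominator
    return exp_avg_sq.sqrt().add_(eps)
```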
    # Apply normalization
    numerator.mul_(step_size_per_element)

    # Compute new norm and rescale to maintain original norm
What is this? I don't see it in the paper.
I see GitHub doesn't show the exact line I commented on; I mean specifically the part about rescaling to maintain the original norm.
    numerator.mul_(vnorm / (vnorm_new + group["eps"]))

    # Apply the same scaling as Muon
    step_size *= math.sqrt(max(rows, cols))
Should probably just directly implement Algorithm 1, line 10 in the paper, instead of this. They are very similar. For an orthonormal matrix, the frobenius norm is sqrt(min(rows, cols)). So if the update were exactly orthogonalized, that line in the paper reduces to this code. But it is approximately orthogonalized, not exact, so there is a difference.
OG Muon does it like this code, but it's better to stick exactly with what the NorMuon paper says.
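A sketch of the paper-faithful scaling, following the reasoning above (the exact formula should be checked against Algorithm 1, line 10 of the paper; the `eps` handling is my assumption):

```python
import math
import torch

def scale_update(update, eps=1e-8):
    """Sketch of the norm-based scaling discussed above. Rescale the
    (approximately orthogonalized) 2-D update by sqrt(rows*cols) divided
    by its Frobenius norm. If the update were exactly orthonormalized,
    its Frobenius norm would be sqrt(min(rows, cols)), and this factor
    would reduce to Muon's sqrt(max(rows, cols)); since orthogonalization
    is only approximate, the two differ slightly."""
    rows, cols = update.shape[-2], update.shape[-1]
    fro = update.norm(dim=(-2, -1), keepdim=True)
    return update * (math.sqrt(rows * cols) / (fro + eps))
```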
Okay, thank you for the review! I'll look into simplifying the NorMuon part. To be transparent, this was vibe-coded based on the Muon, NorMuon, and diffusion-pipe codebases. 😅 The integration of CWD is fully organic, however.

You can also just leave this PR open. I will want to test out NorMuon myself at some point, and I can make the changes myself when I do that if it's not done already.
NorMuon and Cautious Weight Decay have recently set two new records on the GPT speedrun repository and have proven effective in their respective accompanying papers (the latter technique was tested exhaustively, with 20,000 GPU hours).
Both methods work out-of-the-box, with minimal modifications to the Muon/GenericOptim class.
Cautious Weight Decay is intended to improve the final convergence result (a more stable region), while NorMuon simply makes it faster to reach a given loss.
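For readers unfamiliar with it, the core NorMuon step can be sketched roughly like this (buffer names and hyperparameters are illustrative, not this PR's actual code): after orthogonalizing the momentum, each neuron (row) of the update is normalized by a running estimate of its second moment.

```python
import torch

def normuon_normalize(ortho_update, v_buf, beta2=0.95, eps=1e-8):
    """Illustrative sketch of NorMuon's neuron-wise normalization (names
    are mine): keep a per-row second-moment estimate of the orthogonalized
    update and normalize each row by it, Adam-style but per neuron."""
    # per-neuron (row) second moment, updated as an EMA over steps
    v_buf.mul_(beta2).add_(
        (ortho_update * ortho_update).mean(dim=-1, keepdim=True),
        alpha=1 - beta2,
    )
    return ortho_update / (v_buf.sqrt() + eps)
```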
In my experience, NorMuon runs slightly slower than Muon per step, but if it achieves better loss, that's acceptable.
Following ModdedGPT's code, I also added a "schedule" for weight decay: this way it is proportional to the learning rate, so it gradually increases at the beginning of training and gradually decreases at the end (assuming the learning-rate schedule is enabled).
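A rough sketch of how Cautious Weight Decay combines with an lr-proportional decay schedule (the sign convention of the mask is my reading of the paper; check the linked implementations before relying on it):

```python
import torch

def cwd_step(p, update, lr, wd):
    """Sketch of Cautious Weight Decay (CWD), as I understand the paper:
    decoupled weight decay is applied only on coordinates where the
    optimizer update agrees in sign with the parameter, so the decay never
    fights the update. Because the decay term is multiplied by lr, it
    automatically follows the learning-rate schedule, as described above.

    Assumes the step is p <- p - lr * update; if `update` is a descent
    direction instead of a gradient-like quantity, the mask's sign flips."""
    mask = (p * update > 0).to(p.dtype)   # sign-agreement mask (assumption)
    p.mul_(1 - lr * wd * mask)            # masked, decoupled weight decay
    p.add_(update, alpha=-lr)             # the usual parameter step
    return p
```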
NorMuon: https://arxiv.org/pdf/2510.05491, https://github.com/zichongli5/NorMuon.
Cautious Weight Decay: https://arxiv.org/pdf/2510.12402, kozistr/pytorch_optimizer@99dc4a5, KellerJordan/modded-nanogpt#154.
Closes #464.