Add configurable Newton-Schulz iteration dtype #228
77cbc63 to 5a20e43
Greptile Summary
This PR adds a configurable
Confidence Score: 5/5
Safe to merge: defaults are unchanged, the new validation paths are covered by tests, and the dtype/syrk guard is now caught at construction time.
The change is additive and backward-compatible. The only finding is a static-analysis code smell in the restructured use_syrk guard, where sm_version is assigned inside a branch that can be bypassed, leaving it technically unbound for the short-circuit path. This has no runtime impact.
emerging_optimizers/orthogonalized_optimizers/muon.py: the restructured use_syrk block warrants a quick look to confirm static analysers are satisfied after the suggested fix.
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[Muon.__init__ / newton_schulz called] --> B{ns_dtype == 'auto'?}
B -- Yes --> C{fp32_matmul_prec provided?}
C -- Yes explicit param --> D{fp32_matmul_prec == 'medium'?}
C -- No use global state --> E{torch.get_float32_matmul_precision == 'medium'?}
D -- Yes --> F[torch.bfloat16]
D -- No --> G[torch.float32]
E -- Yes --> F
E -- No --> G
B -- No 'float32' --> G
B -- No 'bfloat16' --> F
B -- No 'float16' --> H[torch.float16]
F --> I{use_syrk=True?}
G --> J{use_syrk=True?}
H --> K{use_syrk=True?}
I -- Yes --> L[Use newton_schulz_step_tsyrk]
J -- Yes --> M[WARN: fallback to addmm/baddbmm]
K -- Yes --> M
I -- No --> N[Use newton_schulz_step / batched]
J -- No --> N
K -- No --> N
M --> N
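The decision flow above can be sketched in plain Python. This is a minimal sketch and not the code from the PR: the helper names `resolve_ns_dtype` and `select_step_kernel` are hypothetical, dtype names are plain strings standing in for `torch.dtype` objects, and the `global_matmul_prec` argument stands in for whatever `torch.get_float32_matmul_precision()` would return.

```python
from typing import Optional

VALID_NS_DTYPES = ("auto", "float32", "bfloat16", "float16")


def resolve_ns_dtype(
    ns_dtype: str,
    fp32_matmul_prec: Optional[str] = None,
    global_matmul_prec: str = "highest",
) -> str:
    """Return the dtype name for Newton-Schulz state (hypothetical helper)."""
    if ns_dtype not in VALID_NS_DTYPES:
        raise ValueError(f"Unsupported ns_dtype: {ns_dtype!r}")
    if ns_dtype != "auto":
        # An explicit choice bypasses the matmul-precision lookup entirely.
        return ns_dtype
    # "auto": an explicit precision argument wins over the global setting.
    prec = fp32_matmul_prec if fp32_matmul_prec is not None else global_matmul_prec
    return "bfloat16" if prec == "medium" else "float32"


def select_step_kernel(resolved_dtype: str, use_syrk: bool) -> str:
    """Pick a step kernel name; the syrk path is only taken for bfloat16."""
    if use_syrk and resolved_dtype == "bfloat16":
        return "newton_schulz_step_tsyrk"
    # In the flowchart, use_syrk with any other dtype warns and falls back.
    return "newton_schulz_step"
```

Under these assumptions, `resolve_ns_dtype("auto", "medium")` gives `"bfloat16"`, and `select_step_kernel("float16", True)` falls back to the non-syrk path.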
Reviews (2): Last reviewed commit: "Add configurable Newton-Schulz iteration..."
    1,
)
X = X.to(resolved_ns_dtype)
logging.log_first_n(logging.INFO, f"Using {resolved_ns_dtype} I/O kernels for Newton-Schulz iteration.", 1)
The f-string here pre-formats the message before passing it to log_first_n, which is at odds with the absl logging convention of using lazy %s args. Although absl-py keys the first-n counter on source location (not message content) so counting is unaffected, the string is always allocated and formatted even on suppressed calls. Prefer passing the dtype as a positional %s argument so formatting is deferred.
Suggested change:
- logging.log_first_n(logging.INFO, f"Using {resolved_ns_dtype} I/O kernels for Newton-Schulz iteration.", 1)
+ logging.log_first_n(logging.INFO, "Using %s I/O kernels for Newton-Schulz iteration.", 1, resolved_ns_dtype)
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
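The cost difference between eager and deferred formatting can be illustrated with the standard-library logging module, which uses the same message-plus-args convention. This is only a sketch: `CountingDtype` is a hypothetical stand-in for a dtype object, and the standard-library logger is used here in place of absl.

```python
import logging


class CountingDtype:
    """Stand-in for a dtype object that counts how often it is stringified."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.str_calls = 0

    def __str__(self) -> str:
        self.str_calls += 1
        return self.name


logger = logging.getLogger("ns_lazy_logging_demo")
logger.setLevel(logging.WARNING)  # INFO records are suppressed

eager = CountingDtype("torch.bfloat16")
lazy = CountingDtype("torch.bfloat16")

# Eager: the f-string is built before the logger decides to drop the record.
logger.info(f"Using {eager} I/O kernels for Newton-Schulz iteration.")

# Lazy: the argument is formatted only if the record is actually emitted.
logger.info("Using %s I/O kernels for Newton-Schulz iteration.", lazy)

eager_calls, lazy_calls = eager.str_calls, lazy.str_calls
```

With INFO suppressed, the eager call still stringifies its argument once, while the lazy call never does.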
    ns_dtype: NSDTypeT = "auto",
) -> None:
    if num_ns_steps < 1:
        raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}")
Invalid ns_dtype not caught at construction time
_resolve_ns_dtype raises ValueError for an unrecognized ns_dtype string, but it is only called inside scaled_orthogonalize_fn — a closure executed at step time. A call like Muon(..., ns_dtype="fp16") will construct without error and raise only on the first optimizer step. For the same reason num_ns_steps is validated eagerly above, calling muon_utils._resolve_ns_dtype(ns_dtype) once here (discarding the return value) would catch typos at object creation time rather than mid-training.
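A minimal sketch of the eager-validation pattern described here, assuming a simplified resolver that works on strings. `SketchOptimizer` and this `resolve_ns_dtype` are hypothetical stand-ins, not the Muon class or the `muon_utils` helper.

```python
VALID_NS_DTYPES = ("auto", "float32", "bfloat16", "float16")


def resolve_ns_dtype(ns_dtype: str) -> str:
    """Hypothetical resolver: rejects unknown dtype strings."""
    if ns_dtype not in VALID_NS_DTYPES:
        raise ValueError(f"Unsupported ns_dtype: {ns_dtype!r}")
    return ns_dtype


class SketchOptimizer:
    """Toy optimizer that validates its arguments at construction time."""

    def __init__(self, num_ns_steps: int = 5, ns_dtype: str = "auto") -> None:
        if num_ns_steps < 1:
            raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}")
        # Call the resolver once and discard the result, so a typo such as
        # "fp16" fails here rather than on the first optimizer step.
        resolve_ns_dtype(ns_dtype)
        self.num_ns_steps = num_ns_steps
        self.ns_dtype = ns_dtype

    def step(self) -> str:
        # The resolver still runs at step time; it can no longer fail on a typo.
        return resolve_ns_dtype(self.ns_dtype)
```

With this shape, `SketchOptimizer(ns_dtype="fp16")` raises immediately, while `SketchOptimizer(ns_dtype="float16")` constructs and steps normally.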
Signed-off-by: hyleepp <22672179+hyleepp@users.noreply.github.com>
5a20e43 to 67408f1
Addressed the automated review feedback in 67408f1:
Local checks rerun:
Thanks for bringing this up, and sorry for the delay (I was on vacation). Numerical precision is a complex topic. By design, we provide only a few safe options. The NS function asserts that its input is fp32. The matmul precision setting controls only the compute type, not storage. We have to cast to bf16 because native PyTorch has few "medium" kernels and falls back to "high" (TF32). We are not inclined to provide those kernels (FP32 I/O, BF16 compute) in this repo, so a shortcut (but not equivalent) version was provided. It emits a warning, and a comment in the code explains the difference in more detail. If that turns out to be insufficient, users can always fall back to "high".
FP16 will not be supported, as there is no native PyTorch option for an FP32 I/O matmul to use fp16 precision for compute. FP16's narrow exponent range may also introduce other issues. If the 3-bit mantissa difference between fp16 and bf16 makes the difference, the computation is probably in a dangerous range, i.e. it is hard to tell whether 10 bits are enough. In that case, consider using FP32 (or the MOP optimizer) instead.
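The mantissa and exponent trade-off mentioned above can be checked with a small standard-library sketch. It emulates the two formats by rounding Python floats: fp16 through the `struct` module's half-precision format, and bf16 by rounding the upper 16 bits of the float32 encoding. The helper names are hypothetical, and NaN inputs are not handled.

```python
import struct


def round_to_fp16(x: float) -> float:
    """Round to IEEE half precision (5 exponent bits, 10 mantissa bits)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def round_to_bf16(x: float) -> float:
    """Round to bfloat16 (8 exponent bits, 7 mantissa bits), nearest-even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round to nearest, ties to even, then drop the low 16 bits.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def fp16_overflows(x: float) -> bool:
    """Return True if x is too large to be packed as a finite fp16 value."""
    try:
        struct.pack("<e", x)
    except OverflowError:
        return True
    return False


near_one = 1.0 + 2.0**-9   # needs more than 7 mantissa bits
large = 70000.0            # beyond the largest finite fp16 value (65504)
```

Here `near_one` survives fp16 rounding but collapses to 1.0 in bf16, while `large` is representable (coarsely) in bf16 but overflows fp16.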
Orthogonalization quality measurably varies with precision and the number of NS steps, e.g. by reconstructing the identity and checking how far the diagonal values are from 1 and the off-diagonal values are from 0. The difficult part, however, is determining whether and how it affects LLM training. It is tightly coupled with hyperparameters and everything else, so we leave the choice to users. One would probably need to fit a scaling law to get a sense of how sensitive training accuracy is to orthogonalization quality.
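The identity-reconstruction check described above can be sketched in pure Python. Note the assumptions: this uses the textbook cubic Newton-Schulz iteration X <- 1.5 X - 0.5 (X X^T) X in double precision, not the quintic coefficients or the kernels from this repo, and all function names are hypothetical.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def cubic_newton_schulz(x: Matrix, num_steps: int) -> Matrix:
    """Textbook cubic iteration, applied after Frobenius normalization."""
    norm = math.sqrt(sum(v * v for row in x for v in row))
    x = [[v / norm for v in row] for row in x]
    for _ in range(num_steps):
        gram_x = matmul(matmul(x, transpose(x)), x)  # (X X^T) X
        x = [[1.5 * v - 0.5 * g for v, g in zip(rx, rg)] for rx, rg in zip(x, gram_x)]
    return x


def orthogonality_error(x: Matrix) -> Tuple[float, float]:
    """Return (max |diag - 1|, max |off-diag|) of X X^T for a wide matrix X."""
    gram = matmul(x, transpose(x))
    n = len(gram)
    diag_err = max(abs(gram[i][i] - 1.0) for i in range(n))
    off_err = max(abs(gram[i][j]) for i in range(n) for j in range(n) if i != j)
    return diag_err, off_err


x0 = [[2.0, 1.0, 0.0], [0.0, 1.0, 1.0]]
err_1_step = orthogonality_error(cubic_newton_schulz(x0, 1))
err_10_steps = orthogonality_error(cubic_newton_schulz(x0, 10))
```

Comparing `err_1_step` with `err_10_steps` shows the error shrinking as steps are added; repeating the measurement with reduced-precision arithmetic would be the analogous experiment for dtype sensitivity.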

Summary
This PR adds an explicit ns_dtype option for Newton-Schulz iterations and exposes it through Muon.

The default is ns_dtype="auto", which preserves the current behavior:
- bfloat16 for Newton-Schulz state/intermediates when torch.get_float32_matmul_precision() == "medium"
- float32 otherwise

Users can now explicitly select:
- ns_dtype="float32"
- ns_dtype="bfloat16"
- ns_dtype="float16"

use_syrk=True remains limited to bfloat16; non-bfloat16 dtypes fall back to the addmm/baddbmm path.

Motivation
The current implementation always maps the "medium" matmul-precision path to BF16 I/O Newton-Schulz. That is a solid default, but recent large-scale Muon usage suggests there is value in making the iteration dtype configurable for controlled ablations.

In Step 3.5 Flash, the authors describe Muon's reduced-precision polar-factor iteration as numerically sensitive and report that BF16 Polar Express could rarely produce extreme intermediate outliers due to cumulative addition error. Their mitigation was to cast only the Polar Express iteration state/intermediates to FP16 while keeping the rest of training mixed precision:
https://arxiv.org/html/2602.10604v2#S4.SS1.SSS1
This PR does not change defaults. It simply makes this kind of experiment possible without downstream patching.
Local numeric sanity check
On a local CPU run with normalized inputs, explicit FP16 addmm Newton-Schulz had smaller error than BF16 relative to a full FP32 Newton-Schulz reference. This is only a local sanity check, not a CUDA Tensor Core benchmark:
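| Case | BF16 error | FP16 error |
| --- | --- | --- |
| (32, 64), 1-step simple | 2.38e-3 | 2.86e-4 |
| (64, 32), 1-step simple | 2.35e-3 | 3.04e-4 |
| (128, 128), 1-step simple | 2.34e-3 | 2.95e-4 |
| (32, 64), 5-step quintic | 5.92e-3 | 7.49e-4 |
| (64, 32), 5-step quintic | 8.88e-3 | 1.05e-3 |
| (128, 128), 5-step quintic | 6.73e-3 | 8.59e-4 |

Tests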
Local CPU checks:
- PYTHONPATH=<repo>:<test-deps> <python-with-torch> tests/test_muon_utils.py --device=cpu
- uvx ruff check emerging_optimizers/orthogonalized_optimizers/muon_utils.py emerging_optimizers/orthogonalized_optimizers/muon.py tests/test_muon_utils.py tests/test_orthogonalized_optimizer.py
- <python-with-torch> -m py_compile emerging_optimizers/orthogonalized_optimizers/muon_utils.py emerging_optimizers/orthogonalized_optimizers/muon.py tests/test_muon_utils.py tests/test_orthogonalized_optimizer.py

Note: I could not run the full optimizer test file locally because Triton has no macOS wheel in this environment. The added test_muon_utils.py coverage is CPU-only and passed locally.