Skip to content

Add configurable Newton-Schulz iteration dtype#228

Closed
hyleepp wants to merge 1 commit into
NVIDIA-NeMo:mainfrom
hyleepp:codex/ns-dtype-fp16
Closed

Add configurable Newton-Schulz iteration dtype#228
hyleepp wants to merge 1 commit into
NVIDIA-NeMo:mainfrom
hyleepp:codex/ns-dtype-fp16

Conversation

@hyleepp

@hyleepp hyleepp commented Jun 15, 2026

Copy link
Copy Markdown

Summary

This PR adds an explicit ns_dtype option for Newton-Schulz iterations and exposes it through Muon.

The default is ns_dtype="auto", which preserves the current behavior:

  • use bfloat16 for Newton-Schulz state/intermediates when torch.get_float32_matmul_precision() == "medium"
  • otherwise keep the Newton-Schulz iteration in float32

Users can now explicitly select:

  • ns_dtype="float32"
  • ns_dtype="bfloat16"
  • ns_dtype="float16"

use_syrk=True remains limited to bfloat16; non-bfloat16 dtypes fall back to the addmm/baddbmm path.

Motivation

The current implementation always maps the "medium" matmul-precision path to BF16 I/O Newton-Schulz. That is a solid default, but recent large-scale Muon usage suggests there is value in making the iteration dtype configurable for controlled ablations.

In Step 3.5 Flash, the authors describe Muon's reduced-precision polar-factor iteration as numerically sensitive and report that BF16 Polar Express could rarely produce extreme intermediate outliers due to cumulative addition error. Their mitigation was to cast only the Polar Express iteration state/intermediates to FP16 while keeping the rest of training mixed precision:

https://arxiv.org/html/2602.10604v2#S4.SS1.SSS1

This PR does not change defaults. It simply makes this kind of experiment possible without downstream patching.

Local numeric sanity check

On a local CPU run with normalized inputs, explicit FP16 addmm Newton-Schulz had smaller error than BF16 relative to a full FP32 Newton-Schulz reference. This is only a local sanity check, not a CUDA Tensor Core benchmark:

shape / schedule BF16 addmm rel-Fro error FP16 addmm rel-Fro error
(32, 64), 1-step simple 2.38e-3 2.86e-4
(64, 32), 1-step simple 2.35e-3 3.04e-4
(128, 128), 1-step simple 2.34e-3 2.95e-4
(32, 64), 5-step quintic 5.92e-3 7.49e-4
(64, 32), 5-step quintic 8.88e-3 1.05e-3
(128, 128), 5-step quintic 6.73e-3 8.59e-4

Tests

Local CPU checks:

  • PYTHONPATH=<repo>:<test-deps> <python-with-torch> tests/test_muon_utils.py --device=cpu
  • uvx ruff check emerging_optimizers/orthogonalized_optimizers/muon_utils.py emerging_optimizers/orthogonalized_optimizers/muon.py tests/test_muon_utils.py tests/test_orthogonalized_optimizer.py
  • <python-with-torch> -m py_compile emerging_optimizers/orthogonalized_optimizers/muon_utils.py emerging_optimizers/orthogonalized_optimizers/muon.py tests/test_muon_utils.py tests/test_orthogonalized_optimizer.py

Note: I could not run the full optimizer test file locally because Triton has no macOS wheel in this environment. The added test_muon_utils.py coverage is CPU-only and passed locally.

@copy-pr-bot

copy-pr-bot Bot commented Jun 15, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@hyleepp hyleepp force-pushed the codex/ns-dtype-fp16 branch from 77cbc63 to 5a20e43 Compare June 15, 2026 14:58
@greptile-apps

greptile-apps Bot commented Jun 15, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR adds a configurable ns_dtype parameter to newton_schulz, newton_schulz_tp, and Muon, allowing explicit control over the dtype used for Newton-Schulz iteration state and intermediates ("auto", "float32", "bfloat16", "float16"). The default "auto" preserves existing behavior, so there is no change to defaults.

  • _validate_ns_dtype and _resolve_ns_dtype are introduced in muon_utils.py to centralize dtype resolution; Muon.__init__ now validates ns_dtype eagerly and guards the use_syrk compatibility check at construction time (consistent with the existing Triton/SM checks).
  • Tests cover all explicit dtypes, "auto" precision-preservation, invalid-dtype errors at both the utility and optimizer level, and an end-to-end Muon step with ns_dtype=\"float16\".

Confidence Score: 5/5

Safe to merge — defaults are unchanged, the new validation paths are covered by tests, and the dtype/syrk guard is now caught at construction time.

The change is additive and backward-compatible. The only finding is a static-analysis code smell in the restructured use_syrk guard where sm_version is assigned inside a branch that can be bypassed, leaving it technically unbound for the short-circuit path — no runtime impact.

emerging_optimizers/orthogonalized_optimizers/muon.py — the restructured use_syrk block warrants a quick look to confirm static analysers are satisfied after the suggested fix.

Important Files Changed

Filename Overview
emerging_optimizers/orthogonalized_optimizers/muon_utils.py Adds NSDTypeT, _validate_ns_dtype, and _resolve_ns_dtype; threads ns_dtype through newton_schulz and newton_schulz_tp. Logic is correct; the use_syrk fallback warning paths are properly handled. The log_first_n call now uses lazy %s formatting as previously requested.
emerging_optimizers/orthogonalized_optimizers/muon.py Adds ns_dtype parameter to Muon.init with eager validation and syrk-compatibility guard. One code-quality issue: sm_version is potentially unbound in the restructured use_syrk guard (safe at runtime via short-circuit but flagged by static analysers).
tests/test_muon_utils.py Adds targeted tests for explicit dtypes, auto precision preservation against both medium and highest matmul precision, and invalid-dtype error handling. Coverage is solid for the public API surface.
tests/test_orthogonalized_optimizer.py Adds two Muon-level tests: end-to-end dtype propagation with ns_dtype=float16 and init-time ValueError for an invalid dtype. Both tests are well-scoped and verify the previously requested eager validation.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Muon.__init__ / newton_schulz called] --> B{ns_dtype == 'auto'?}
    B -- Yes --> C{fp32_matmul_prec provided?}
    C -- Yes explicit param --> D{fp32_matmul_prec == 'medium'?}
    C -- No use global state --> E{torch.get_float32_matmul_precision == 'medium'?}
    D -- Yes --> F[torch.bfloat16]
    D -- No --> G[torch.float32]
    E -- Yes --> F
    E -- No --> G
    B -- No 'float32' --> G
    B -- No 'bfloat16' --> F
    B -- No 'float16' --> H[torch.float16]
    F --> I{use_syrk=True?}
    G --> J{use_syrk=True?}
    H --> K{use_syrk=True?}
    I -- Yes --> L[Use newton_schulz_step_tsyrk]
    J -- Yes --> M[WARN: fallback to addmm/baddbmm]
    K -- Yes --> M
    I -- No --> N[Use newton_schulz_step / batched]
    J -- No --> N
    K -- No --> N
    M --> N
Loading

Reviews (2): Last reviewed commit: "Add configurable Newton-Schulz iteration..." | Re-trigger Greptile

1,
)
X = X.to(resolved_ns_dtype)
logging.log_first_n(logging.INFO, f"Using {resolved_ns_dtype} I/O kernels for Newton-Schulz iteration.", 1)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 The f-string here pre-formats the message before passing it to log_first_n, which is at odds with the absl logging convention of using lazy %s args. Although absl-py keys the first-n counter on source location (not message content) so counting is unaffected, the string is always allocated and formatted even on suppressed calls. Prefer passing the dtype as a positional %s argument so formatting is deferred.

Suggested change
logging.log_first_n(logging.INFO, f"Using {resolved_ns_dtype} I/O kernels for Newton-Schulz iteration.", 1)
logging.log_first_n(logging.INFO, "Using %s I/O kernels for Newton-Schulz iteration.", 1, resolved_ns_dtype)

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines +91 to 94
ns_dtype: NSDTypeT = "auto",
) -> None:
if num_ns_steps < 1:
raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Invalid ns_dtype not caught at construction time

_resolve_ns_dtype raises ValueError for an unrecognized ns_dtype string, but it is only called inside scaled_orthogonalize_fn — a closure executed at step time. A call like Muon(..., ns_dtype="fp16") will construct without error and raise only on the first optimizer step. For the same reason num_ns_steps is validated eagerly above, calling muon_utils._resolve_ns_dtype(ns_dtype) once here (discarding the return value) would catch typos at object creation time rather than mid-training.

Signed-off-by: hyleepp <22672179+hyleepp@users.noreply.github.com>
@hyleepp hyleepp force-pushed the codex/ns-dtype-fp16 branch from 5a20e43 to 67408f1 Compare June 15, 2026 15:56
@hyleepp

hyleepp commented Jun 15, 2026

Copy link
Copy Markdown
Author

Addressed the automated review feedback in 67408f1:

  • Added eager validation for invalid ns_dtype in Muon.__init__.
  • Disabled use_syrk at init time when the resolved Newton-Schulz dtype is not bfloat16, matching the existing Triton/SM fallback pattern.
  • Switched the log_first_n message to lazy formatting.
  • Amended the commit with a DCO Signed-off-by line using the GitHub noreply identity.

Local checks rerun:

  • ruff check on the touched files.
  • tests/test_muon_utils.py --device=cpu with 79 tests run, 2 skipped for unavailable Triton/GPU kernels.

@skyw

skyw commented Jun 21, 2026

Copy link
Copy Markdown
Contributor

Thanks for bringing this up. (was on vacation, sorry for the delay)

Numerical precision is a complex topic. By design, we only provide very few and safe options. The NS function assert input must be fp32. The matmul precision is only meant for compute type, not storage. The reason we have to cast to bf16 is Pytorch native doesn't really have a lot of "medium" kernel and will fallback to "high" (TF32). We are not incline to provide those kernels (FP32 I/O, BF16 compute) in this repo, so a short cut (but not equivalent) version was provided. It does give warning message and comment in the code explains the difference in more detail. If it turned out to be insufficient, user can always fallback to "high".

FP16 will not be supported as there is no pytorch native option for FP32 I/O matmul to use fp16 precision for compute. FP16's narrow exponent range may introduce other issues. If the 3bits mantissa difference between fp16 and bf16 made the difference, it is probably in a dangerous range, i.e. hard to tell 10 is enough. in which case, maybe use FP32 (or MOP optimizer) instead.

@hyleepp

hyleepp commented Jun 22, 2026

Copy link
Copy Markdown
Author

Thanks for bringing this up. (was on vacation, sorry for the delay)

Numerical precision is a complex topic. By design, we only provide very few and safe options. The NS function assert input must be fp32. The matmul precision is only meant for compute type, not storage. The reason we have to cast to bf16 is Pytorch native doesn't really have a lot of "medium" kernel and will fallback to "high" (TF32). We are not incline to provide those kernels (FP32 I/O, BF16 compute) in this repo, so a short cut (but not equivalent) version was provided. It does give warning message and comment in the code explains the difference in more detail. If it turned out to be insufficient, user can always fallback to "high".

FP16 will not be supported as there is no pytorch native option for FP32 I/O matmul to use fp16 precision for compute. FP16's narrow exponent range may introduce other issues. If the 3bits mantissa difference between fp16 and bf16 made the difference, it is probably in a dangerous range, i.e. hard to tell 10 is enough. in which case, maybe use FP32 (or MOP optimizer) instead.

Many thanks for the detailed reply. I agree that if the 3-bit mantissa difference between fp16 and bf16 is enough to matter, then the computation is probably already in a rather fragile regime.

My initial motivation was that, after reading the paper, I tried to simulate the behavior under different data types and observed some strange spikes in PE6. I also ran a few small-scale experiments on MoE models, but did not find any noticeable difference.

So my question is mostly whether using fewer Newton-Schulz iterations with bf16 could introduce subtle numerical issues that may not be visible at small scale, but become more likely to appear in larger models. I was also curious about this because DeepSeek appears to use 10 iterations in their V4 Pro setting.

That said, I agree that it is probably better to keep the current behavior unchanged for now, unless someone can provide more concrete evidence that this causes real issues in practice.
image

@skyw

skyw commented Jun 22, 2026

Copy link
Copy Markdown
Contributor

So my question is mostly whether using fewer Newton-Schulz iterations with bf16 could introduce subtle numerical issues that may not be visible at small scale, but become more likely to appear in larger models. I was also curious about this because DeepSeek appears to use 10 iterations in their V4 Pro setting.

It is measurable that orthogonal quality varies with precision and number of NS steps. i.e. reconstructing the identity, and check how far diagonal values are from 1 and off diagonal values are from 0. The difficult part is, however, determine whether/how it affects LLM training. Things are highly coupled with hyper parameter and all rest of it, so we leave the choice to users. Probably needs to fit a scaling low to get a sense how sensitive orthogonalization quality is to training accuracy.

@skyw skyw closed this Jun 22, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants