Skip to content

Add option to normalize NS input in double precision#238

Open
skyw wants to merge 5 commits into
mainfrom
skyw/handle_tiny_values_in_muon
Open

Add option to normalize NS input in double precision#238
skyw wants to merge 5 commits into
mainfrom
skyw/handle_tiny_values_in_muon

Conversation

@skyw

@skyw skyw commented Jun 26, 2026

Copy link
Copy Markdown
Contributor

"if" clause on calculated norm would trigger device to host sync, trying to do it on device will take multiple path, in which case none of them are better than just use double. A custom kernel can do it but is out of scope.

No zero division guard for fp64 path as we don't imagine a case that LLM training is done in fp64. Square of fp32 value won't underflow fp64.

skyw added 2 commits June 26, 2026 15:02
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
@copy-pr-bot

copy-pr-bot Bot commented Jun 26, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@greptile-apps

greptile-apps Bot commented Jun 26, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR adds a normalize_in_double flag to newton_schulz and distributed_normalize_p2 to handle gradient tensors whose Frobenius norm underflows to zero in float32. When enabled, the squared-sum accumulation (or norm computation via torch.linalg.vector_norm) is performed in float64 before being cast back to float32 for the division, avoiding the underflow without requiring a CPU–GPU sync.

  • Non-distributed path: Uses torch.linalg.vector_norm(..., dtype=torch.float64) with keepdim=True; eps is silently dropped in this branch and a warning is emitted (see comment).
  • Distributed path: Casts x to float64 before squaring and summing, then converts the sqrt back to x.dtype; clamp_min_ (in-place) is applied only when normalize_in_double=False.
  • Tests: Replaces the old scale-invariance regression with two new tests targeting the double-precision path, including a case where the float32 norm genuinely underflows.

Confidence Score: 5/5

Safe to merge; the double-precision normalization logic is correct and the in-place clamp guard in the distributed path works as intended.

The normalization math is sound for both the distributed and non-distributed paths. The only notable issue is the unthrottled warning call, which does not affect correctness.

The logging.warning call in muon_utils.py (non-distributed normalize_in_double branch) should be rate-limited before merging to avoid log spam during training.

Important Files Changed

Filename Overview
emerging_optimizers/orthogonalized_optimizers/muon_utils.py Adds normalize_in_double option to avoid float32 underflow in tiny-norm inputs; the non-distributed path emits an unthrottled logging.warning on every call, unlike the existing log_first_n pattern in the same function.
tests/test_muon_utils.py Replaces scale-invariance test with two new tests covering the normalize_in_double path; test_normalization_scale_invariant at exp2=-20 and exp2=-40 don't actually exercise the underflow branch since those norms are well above eps.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[newton_schulz called] --> B{tp_group?}
    B -- yes --> C[distributed_normalize_p2]
    C --> D{normalize_in_double?}
    D -- yes --> E["x_sq = x.double()\nsum in float64\nall_reduce\nnorm = sqrt.to(float32)\nx / norm"]
    D -- no --> F["x_sq = x\nsum in float32\nall_reduce\nnorm = sqrt.to(float32)\nnorm.clamp_min_(eps)\nx / norm"]
    B -- no --> G{normalize_in_double?}
    G -- no --> H["F.normalize(x, p=2,\ndim=(-2,-1), eps=eps)"]
    G -- yes --> I["logging.warning (every call)\nnorm = vector_norm(x,\ndtype=float64).to(float32)\nx / norm"]
    E --> J[NS iterations]
    F --> J
    H --> J
    I --> J
    J --> K[return X]
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
    A[newton_schulz called] --> B{tp_group?}
    B -- yes --> C[distributed_normalize_p2]
    C --> D{normalize_in_double?}
    D -- yes --> E["x_sq = x.double()\nsum in float64\nall_reduce\nnorm = sqrt.to(float32)\nx / norm"]
    D -- no --> F["x_sq = x\nsum in float32\nall_reduce\nnorm = sqrt.to(float32)\nnorm.clamp_min_(eps)\nx / norm"]
    B -- no --> G{normalize_in_double?}
    G -- no --> H["F.normalize(x, p=2,\ndim=(-2,-1), eps=eps)"]
    G -- yes --> I["logging.warning (every call)\nnorm = vector_norm(x,\ndtype=float64).to(float32)\nx / norm"]
    E --> J[NS iterations]
    F --> J
    H --> J
    I --> J
    J --> K[return X]
Loading

Reviews (3): Last reviewed commit: "bug fix" | Re-trigger Greptile

Comment thread emerging_optimizers/orthogonalized_optimizers/muon_utils.py Outdated
Comment thread emerging_optimizers/orthogonalized_optimizers/muon_utils.py Outdated
Comment thread tests/test_muon_utils.py
skyw added 2 commits June 26, 2026 16:09
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Comment thread emerging_optimizers/orthogonalized_optimizers/muon_utils.py
Signed-off-by: Hao Wu <skyw@nvidia.com>
@skyw skyw requested a review from FDecaYed June 26, 2026 23:21
@skyw

skyw commented Jun 26, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 854dfd9

@skyw skyw changed the title Handle tiny values in muon better Add option to normalize NS input in double precision Jun 26, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant