
newton_schulz silently degenerates (non-orthogonal output) for small-norm inputs due to F.normalize eps floor #229

Description

@yuchenwang3

Describe the bug

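newton_schulz is expected to be scale-invariant: for any input scale it should return a (semi-)orthogonal matrix with ‖UVᵀ‖_F ≈ √min(m,n), where x = UΣVᵀ. It does not. The prelude X = F.normalize(x, p=2, dim=(-2,-1), eps=1e-7) divides by max(‖x‖_F, eps), i.e. by eps instead of ‖x‖_F once ‖x‖_F < eps, so the normalized matrix has norm ‖x‖_F/eps ≪ 1. The Newton–Schulz iteration is tuned for a normalized input whose singular values are not far below 1, and with a fixed number of steps it can grow a small singular value only by a bounded factor (at most a ≈ 3.4445 per step, about 1.7e3 over the 6 steps used in the repro below), so it cannot lift them back. The output is then a degenerate, non-orthogonal matrix whose norm collapses toward 0, silently, with no warning or error.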
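Both quantities in the paragraph above can be checked in a few lines. This is a sketch in plain torch, assuming only the same F.normalize call as the prelude; the 256×256 size, the seed and the constant names are illustrative choices, not taken from the package.

```python
import torch
import torch.nn.functional as F

EPS = 1e-7       # eps passed to F.normalize in the prelude
A_COEF = 3.4445  # linear coefficient a of the quintic iteration
STEPS = 6        # step count used in the repro

torch.manual_seed(0)
base = torch.randn(256, 256)
base /= base.norm()  # unit Frobenius norm, so ||base * n||_F ~ n

for n in (1e0, 1e-8, 1e-10):
    X = F.normalize(base * n, p=2, dim=(-2, -1), eps=EPS)
    # F.normalize divides by max(||x||_F, eps), so below eps the result has norm n / eps
    expected = min(n / EPS, 1.0)
    print(f"||x||_F={n:.0e}  ||normalize(x)||_F={X.norm().item():.3e}  expected={expected:.3e}")

# For s**2 < -b/c (about 2.35), p(s) = a*s + b*s**3 + c*s**5 < a*s, so a small
# singular value grows by at most a per step and a**STEPS over the whole loop.
max_gain = A_COEF ** STEPS
print(f"max gain over {STEPS} steps ~ {max_gain:.0f}")
```

With ‖x‖_F = 1e-10 the normalized input has norm 1e-3, and its singular values are far smaller still, so even the maximal gain of roughly 1.7e3 leaves them well below 1.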

Steps/Code to reproduce bug

Minimal, torch-only repro (no GPU, runs in seconds) that mirrors the F.normalize(..., eps=1e-7) prelude exactly:

import torch, torch.nn.functional as F
torch.manual_seed(0); torch.set_float32_matmul_precision("high")

def newton_schulz(x, steps=6, eps=1e-7):
    a, b, c = 3.4445, -4.7750, 2.0315                 # quintic coefficients used by Muon
    X = F.normalize(x, p=2, dim=(-2, -1), eps=eps)   # the line under scrutiny: divides by max(||x||_F, eps)
    for _ in range(steps):
        A = X @ X.transpose(-2, -1)                   # A = X Xᵀ
        X = a * X + (b * A + c * (A @ A)) @ X        # X <- aX + b(XXᵀ)X + c(XXᵀ)²X
    return X

m = 2048; base = torch.randn(m, m); base /= base.norm()   # unit-Frobenius direction
for N in (1e0, 1e-6, 1e-8, 1e-9, 1e-10):
    print(f"||in||_F={N:>7.0e}  ||out||_F={newton_schulz(base*N).norm():8.3f}  (ideal ~{m**0.5:.1f})")

Output:

||in||_F=  1e+00  ||out||_F=  39.929   (ideal ~45.3)
||in||_F=  1e-06  ||out||_F=  39.929
||in||_F=  1e-08  ||out||_F=  40.118
||in||_F=  1e-09  ||out||_F=  16.135   <- DEGENERATE
||in||_F=  1e-10  ||out||_F=   1.669   <- DEGENERATE

Same behavior against the package itself (from emerging_optimizers.orthogonalized_optimizers.muon_utils import newton_schulz, called with coefficient_type="polar_express"): ‖out‖_F is 1e0 → 44.93, 1e-8 → 42.11, 1e-10 → 1.83.
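The output norm is an indirect signal; orthogonality can also be measured directly as the relative residual ‖XXᵀ − I‖_F / ‖I‖_F. The sketch below copies the repro's prelude and coefficients so it runs on its own, at 256×256 only for speed. `ortho_error` is a hypothetical helper, not part of the package.

```python
import torch
import torch.nn.functional as F


def newton_schulz(x, steps=6, eps=1e-7):
    # Same eps-floored prelude and quintic coefficients as the repro in this issue
    a, b, c = 3.4445, -4.7750, 2.0315
    X = F.normalize(x, p=2, dim=(-2, -1), eps=eps)
    for _ in range(steps):
        A = X @ X.transpose(-2, -1)
        X = a * X + (b * A + c * (A @ A)) @ X
    return X


def ortho_error(X):
    # Relative residual ||X X^T - I||_F / ||I||_F; 0 for an exactly orthogonal X
    eye = torch.eye(X.shape[-2], dtype=X.dtype)
    return ((X @ X.transpose(-2, -1) - eye).norm() / eye.norm()).item()


torch.manual_seed(0)
m = 256
base = torch.randn(m, m)
base /= base.norm()  # unit Frobenius norm

errors = {}
for scale in (1e0, 1e-8, 1e-10):
    errors[scale] = ortho_error(newton_schulz(base * scale))
    print(f"||in||_F={scale:.0e}  ortho_error={errors[scale]:.3f}")
```

The residual is not zero even at scale 1, because this quintic iteration drives singular values into a band around 1 rather than exactly to 1. At 1e-10, however, every singular value stays far below 1, so the residual approaches 1: XXᵀ is close to zero rather than to I.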

Expected behavior

A scale-invariant orthogonalizer should return the same ‖out‖_F regardless of input scale (ideally ≈ √min(m,n), or at least the plateau value seen at larger scales) instead of collapsing once ‖x‖_F drops below the eps floor. At minimum, it should warn or raise rather than silently emit a non-orthogonal result.
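The expected property can be phrased as a direct check: for s > 0, ‖orthogonalize(s·x)‖_F should match ‖orthogonalize(x)‖_F. As a reference that satisfies this exactly, the sketch uses the polar factor UVᵀ computed with torch.linalg.svd. `scale_invariance_gap` and `polar_factor` are hypothetical helper names, and newton_schulz is the repro's version, copied so the block runs on its own.

```python
import torch
import torch.nn.functional as F


def newton_schulz(x, steps=6, eps=1e-7):
    # Same eps-floored prelude and coefficients as the repro in this issue
    a, b, c = 3.4445, -4.7750, 2.0315
    X = F.normalize(x, p=2, dim=(-2, -1), eps=eps)
    for _ in range(steps):
        A = X @ X.transpose(-2, -1)
        X = a * X + (b * A + c * (A @ A)) @ X
    return X


def polar_factor(x):
    # Exact orthogonal polar factor U V^T; unchanged by positive rescaling of x
    U, _, Vh = torch.linalg.svd(x, full_matrices=False)
    return U @ Vh


def scale_invariance_gap(orthogonalize, x, scales):
    # Largest relative change of ||orthogonalize(s * x)||_F compared with s = 1
    ref = orthogonalize(x).norm()
    return max(((orthogonalize(x * s).norm() - ref).abs() / ref).item() for s in scales)


torch.manual_seed(0)
x = torch.randn(256, 256)
x /= x.norm()
scales = (1e-6, 1e-8, 1e-10)

print("polar factor gap: ", scale_invariance_gap(polar_factor, x, scales))
print("newton_schulz gap:", scale_invariance_gap(newton_schulz, x, scales))
```

A check of this shape could serve as a regression test: the SVD reference passes at every scale, while the eps-floored prelude fails once the smallest scale is below 1e-7.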

Additional context

  • Location: emerging_optimizers/orthogonalized_optimizers/muon_utils.py, newton_schulz(), whose default eps: float = 1e-7 is passed to F.normalize.
  • Suggested fix: use eps=1e-30 purely as a divide-by-zero guard (fine in float32/bfloat16, though 1e-30 rounds to 0 in float16), normalize by ‖x‖_F.clamp_min(tiny) with tiny near the dtype floor (e.g. torch.finfo(x.dtype).tiny), or detect and warn when ‖x‖_F < eps.
  • Why it matters: combined with a framework that applies gradient-norm clipping to the Muon param group (e.g. Megatron-LM ChainedOptimizer), the clip coefficient can scale individual per-matrix gradients below this floor → Newton–Schulz silently emits degenerate updates → training stalls while the forward pass and loss look completely normal (very hard to diagnose). Filed separately on Megatron-LM; the two interact.
  • Environment: Emerging-Optimizers 0.4.0a0, PyTorch 2.11+cu128 (the repro above needs only torch and runs on CPU).
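A minimal sketch of the clamp_min(tiny) fix suggested above, combined with the optional warning. It assumes the rest of newton_schulz stays as in the repro. `newton_schulz_fixed` and its `warn_below` parameter are hypothetical names, not the package's API.

```python
import warnings

import torch


def newton_schulz_fixed(x, steps=6, warn_below=None):
    # Hypothetical variant: normalize by ||x||_F with only a dtype-level
    # divide-by-zero guard, so the prelude is scale-invariant for any x != 0.
    a, b, c = 3.4445, -4.7750, 2.0315
    norm = torch.linalg.vector_norm(x, ord=2, dim=(-2, -1), keepdim=True)
    if warn_below is not None and bool((norm < warn_below).any()):
        # Optional diagnostic; bool() on a CUDA tensor forces a host sync
        warnings.warn(f"newton_schulz input has ||x||_F < {warn_below:g}", RuntimeWarning)
    X = x / norm.clamp_min(torch.finfo(x.dtype).tiny)  # all-zero input stays zero, no NaN
    for _ in range(steps):
        A = X @ X.transpose(-2, -1)
        X = a * X + (b * A + c * (A @ A)) @ X
    return X


torch.manual_seed(0)
base = torch.randn(256, 256)
base /= base.norm()
for n in (1e0, 1e-8, 1e-10):
    print(f"||in||_F={n:.0e}  ||out||_F={newton_schulz_fixed(base * n).norm().item():.3f}")
```

Because the denominator is the actual norm down to the smallest normal number of the dtype, the normalized input is the same matrix at every scale up to rounding, and so is the output. Making the warning opt-in keeps the default path free of the host sync.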
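To illustrate how global gradient-norm clipping can push one matrix below the floor, the sketch below applies torch.nn.utils.clip_grad_norm_ to two hypothetical parameters with very different gradient norms. It is not Megatron-LM's ChainedOptimizer; the parameter names and norms are made up for illustration.

```python
import torch

EPS = 1e-7  # default eps of the F.normalize prelude in newton_schulz


def grad_with_norm(shape, target_norm):
    # Random gradient direction rescaled to a chosen Frobenius norm
    g = torch.randn(shape)
    return g * (target_norm / g.norm())


torch.manual_seed(0)
# Hypothetical parameters: one receives a gradient spike, the other a small gradient
params = {
    "spiky": torch.nn.Parameter(torch.zeros(64, 64)),
    "quiet": torch.nn.Parameter(torch.zeros(64, 64)),
}
params["spiky"].grad = grad_with_norm((64, 64), 1e4)
params["quiet"].grad = grad_with_norm((64, 64), 1e-5)

# Global clipping multiplies every gradient by the same factor, about max_norm / total_norm
total_norm = torch.nn.utils.clip_grad_norm_(list(params.values()), max_norm=1.0)

below_floor = [name for name, p in params.items() if p.grad.norm().item() < EPS]
print(f"total norm before clipping: {total_norm.item():.3e}")
for name, p in params.items():
    print(f"{name}: ||grad||_F after clipping = {p.grad.norm().item():.3e}")
print("below the eps floor:", below_floor)
```

Here the shared clip factor is about 1e-4, so the quiet matrix's gradient drops from 1e-5 to about 1e-9, the range where the repro shows newton_schulz degenerating, even though nothing about that parameter is unusual.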
