diff --git a/emerging_optimizers/orthogonalized_optimizers/muon_utils.py b/emerging_optimizers/orthogonalized_optimizers/muon_utils.py index b29e02b..521fc71 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon_utils.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon_utils.py @@ -141,7 +141,7 @@ def newton_schulz( steps: int, coefficient_type: NSCoeffT = "quintic", custom_coefficient_sets: list[tuple[float, float, float]] | None = None, - eps: float = 1e-7, + eps: float = 1e-15, transpose: bool | None = None, tp_group: torch.distributed.ProcessGroup | None = None, use_syrk: bool = False, @@ -192,7 +192,9 @@ def newton_schulz( if transpose: x = x.mT - # Ensure spectral norm is at most 1 + # Ensure spectral norm is at most 1. + # NOTE: ``eps`` is a divide-by-zero guard; it must stay well below any realistic ``||x||_F`` + # yet remain fp32-safe when squared. See issue #229. if tp_group is not None: X = distributed_normalize_p2(x, eps, tp_group) else: diff --git a/tests/test_muon_utils.py b/tests/test_muon_utils.py index ac58f55..dbc56e6 100644 --- a/tests/test_muon_utils.py +++ b/tests/test_muon_utils.py @@ -120,6 +120,29 @@ def test_newtonschulz5_close_to_reference(self, dim1, dim2): rtol=1e-7, ) + @parameterized.parameters(1e-2, 1e-6, 1e-9, 1e-12) + def test_newtonschulz_small_eps(self, scale): + """Orthogonalization depends only on direction, so scaling the input must not change the output. + + Regression test for issue #229: a too-large ``eps`` in the internal ``F.normalize`` divides + small-norm inputs by ``eps`` instead of their norm, silently degenerating the output. The + orthogonalized result for ``x`` and ``scale * x`` must match for any ``scale > 0``. + """ + x = torch.randn(256, 256, device=self.device, dtype=torch.float32) + x = x / x.norm() # unit Frobenius norm direction + ref = muon_utils.newton_schulz(x, steps=5, coefficient_type="quintic") + out = muon_utils.newton_schulz(scale * x, steps=5, coefficient_type="quintic") + torch.testing.assert_close( + out, + ref, + atol=1e-4, + rtol=1e-5, + msg=lambda m: ( + f"newton_schulz not scale-invariant at input scale {scale}: " + f"||out||_F={out.norm().item():.4f} vs ||ref||_F={ref.norm().item():.4f}\n{m}" + ), + ) + @parameterized.parameters( (2, 256, 256), (4, 128, 256),