diff --git a/emerging_optimizers/orthogonalized_optimizers/muon_utils.py b/emerging_optimizers/orthogonalized_optimizers/muon_utils.py index 521fc71..781ce70 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon_utils.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon_utils.py @@ -129,11 +129,21 @@ def get_coefficient_iterator( return islice(base, steps) -def distributed_normalize_p2(x: torch.Tensor, eps: float, group: torch.distributed.ProcessGroup) -> torch.Tensor: - """Normalize a tensor in a distributed way.""" - x_sq_sum = (x * x).sum() +def distributed_normalize_p2( + x: torch.Tensor, eps: float, group: torch.distributed.ProcessGroup, normalize_in_double: bool = False +) -> torch.Tensor: + """Normalize a tensor by its distributed Frobenius norm. + + When ``normalize_in_double`` is set, the squared sum is accumulated in float64 so that tiny + entries do not underflow to zero when squared in float32. + """ + x_sq = x.double() if normalize_in_double else x + x_sq_sum = (x_sq * x_sq).sum() torch.distributed.all_reduce(x_sq_sum, op=torch.distributed.ReduceOp.SUM, group=group) - return x / torch.sqrt(x_sq_sum).clamp_min(eps) + norm = torch.sqrt(x_sq_sum).to(x.dtype) + if not normalize_in_double: + norm.clamp_min_(eps) + return x / norm def newton_schulz( @@ -145,6 +155,7 @@ def newton_schulz( transpose: bool | None = None, tp_group: torch.distributed.ProcessGroup | None = None, use_syrk: bool = False, + normalize_in_double: bool = False, ) -> torch.Tensor: """Use Newton-Schulz iteration to compute the zeroth power / orthogonalization of x. @@ -177,6 +188,10 @@ def newton_schulz( If None, will be determined based on the size of the tensor. tp_group: The process group for communication if input is distributed. use_syrk: Whether to use the Triton kernel for the Newton-Schulz iteration. + normalize_in_double: Whether to reduce the Frobenius norm in float64. This keeps the squared + sum out of float32 underflow for inputs with very small entries, at the cost of a float64 + reduction. Without customized kernels, manually handle scaling without triggering a device to host + sync are usually more expensive than using double. Returns: The orthogonalization of x. @@ -192,13 +207,17 @@ def newton_schulz( if transpose: x = x.mT - # Ensure spectral norm is at most 1. - # NOTE: ``eps`` is a divide-by-zero guard; it must stay well below any realistic ``||x||_F`` - # yet remain fp32-safe when squared. See issue #229. + # Ensure spectral norm is at most 1 by normalizing with the Frobenius norm. Reducing in float64 + # (``normalize_in_double``) keeps the squared sum out of float32 underflow for tiny-norm inputs. if tp_group is not None: - X = distributed_normalize_p2(x, eps, tp_group) + X = distributed_normalize_p2(x, eps, tp_group, normalize_in_double) else: - X = torch.nn.functional.normalize(x, p=2, dim=(-2, -1), eps=eps) # type: ignore[arg-type] + if not normalize_in_double: + X = torch.nn.functional.normalize(x, p=2, dim=(-2, -1), eps=eps) # type: ignore[arg-type] + else: + # eps is ignored when normalize in double. + norm = torch.linalg.vector_norm(x, dim=(-2, -1), keepdim=True, dtype=torch.float64).to(x.dtype) + X = x / norm if coefficient_type in _COEFFICIENT_SETS: coefficient_sets = _COEFFICIENT_SETS[coefficient_type] diff --git a/tests/test_muon_utils.py b/tests/test_muon_utils.py index 38585c1..4c45848 100644 --- a/tests/test_muon_utils.py +++ b/tests/test_muon_utils.py @@ -121,28 +121,14 @@ def test_newtonschulz5_close_to_reference(self, dim1, dim2): rtol=1e-7, ) - @parameterized.parameters(1e-2, 1e-6, 1e-9, 1e-12) - def test_newtonschulz_small_eps(self, scale): - """Orthogonalization depends only on direction, so scaling the input must not change the output. - - Regression test for issue #229: a too-large ``eps`` in the internal ``F.normalize`` divides - small-norm inputs by ``eps`` instead of their norm, silently degenerating the output. The - orthogonalized result for ``x`` and ``scale * x`` must match for any ``scale > 0``. - """ - x = torch.randn(256, 256, device=self.device, dtype=torch.float32) - x = x / x.norm() # unit Frobenius norm direction - ref = muon_utils.newton_schulz(x, steps=5, coefficient_type="quintic") - out = muon_utils.newton_schulz(scale * x, steps=5, coefficient_type="quintic") - torch.testing.assert_close( - out, - ref, - atol=1e-4, - rtol=1e-5, - msg=lambda m: ( - f"newton_schulz not scale-invariant at input scale {scale}: " - f"||out||_F={out.norm().item():.4f} vs ||ref||_F={ref.norm().item():.4f}\n{m}" - ), - ) + def test_preserve_values_with_underflowed_norm_in_fp64(self): + scale = 1e-30 + x = torch.randn(256, 256, device=self.device, dtype=torch.float32) * scale + assert torch.linalg.vector_norm(x) == 0 # should underflow + norm_ref = torch.linalg.vector_norm(x, dtype=torch.double) + assert norm_ref != 0 + out = muon_utils.newton_schulz(x, steps=0, normalize_in_double=True) + torch.testing.assert_close(x / norm_ref, out, atol=0, rtol=1e-6) @parameterized.parameters( (2, 256, 256),