Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 28 additions & 9 deletions emerging_optimizers/orthogonalized_optimizers/muon_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,21 @@ def get_coefficient_iterator(
return islice(base, steps)


def distributed_normalize_p2(x: torch.Tensor, eps: float, group: torch.distributed.ProcessGroup) -> torch.Tensor:
"""Normalize a tensor in a distributed way."""
x_sq_sum = (x * x).sum()
def distributed_normalize_p2(
x: torch.Tensor, eps: float, group: torch.distributed.ProcessGroup, normalize_in_double: bool = False
) -> torch.Tensor:
"""Normalize a tensor by its distributed Frobenius norm.

When ``normalize_in_double`` is set, the squared sum is accumulated in float64 so that tiny
entries do not underflow to zero when squared in float32.
"""
x_sq = x.double() if normalize_in_double else x
x_sq_sum = (x_sq * x_sq).sum()
torch.distributed.all_reduce(x_sq_sum, op=torch.distributed.ReduceOp.SUM, group=group)
return x / torch.sqrt(x_sq_sum).clamp_min(eps)
norm = torch.sqrt(x_sq_sum).to(x.dtype)
if not normalize_in_double:
norm.clamp_min_(eps)
return x / norm
Comment thread
skyw marked this conversation as resolved.


def newton_schulz(
Expand All @@ -145,6 +155,7 @@ def newton_schulz(
transpose: bool | None = None,
tp_group: torch.distributed.ProcessGroup | None = None,
use_syrk: bool = False,
normalize_in_double: bool = False,
) -> torch.Tensor:
"""Use Newton-Schulz iteration to compute the zeroth power / orthogonalization of x.

Expand Down Expand Up @@ -177,6 +188,10 @@ def newton_schulz(
If None, will be determined based on the size of the tensor.
tp_group: The process group for communication if input is distributed.
use_syrk: Whether to use the Triton kernel for the Newton-Schulz iteration.
normalize_in_double: Whether to reduce the Frobenius norm in float64. This keeps the squared
sum out of float32 underflow for inputs with very small entries, at the cost of a float64
reduction. Without customized kernels, manually handling the scaling without triggering a
device-to-host sync is usually more expensive than using double.

Returns:
The orthogonalization of x.
Expand All @@ -192,13 +207,17 @@ def newton_schulz(
if transpose:
x = x.mT

# Ensure spectral norm is at most 1.
# NOTE: ``eps`` is a divide-by-zero guard; it must stay well below any realistic ``||x||_F``
# yet remain fp32-safe when squared. See issue #229.
# Ensure spectral norm is at most 1 by normalizing with the Frobenius norm. Reducing in float64
# (``normalize_in_double``) keeps the squared sum out of float32 underflow for tiny-norm inputs.
if tp_group is not None:
X = distributed_normalize_p2(x, eps, tp_group)
X = distributed_normalize_p2(x, eps, tp_group, normalize_in_double)
else:
X = torch.nn.functional.normalize(x, p=2, dim=(-2, -1), eps=eps) # type: ignore[arg-type]
if not normalize_in_double:
X = torch.nn.functional.normalize(x, p=2, dim=(-2, -1), eps=eps) # type: ignore[arg-type]
else:
logging.warning("eps is ignored when normalize in double.")
norm = torch.linalg.vector_norm(x, dim=(-2, -1), keepdim=True, dtype=torch.float64).to(x.dtype)
X = x / norm

if coefficient_type in _COEFFICIENT_SETS:
coefficient_sets = _COEFFICIENT_SETS[coefficient_type]
Expand Down
35 changes: 14 additions & 21 deletions tests/test_muon_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,28 +121,21 @@ def test_newtonschulz5_close_to_reference(self, dim1, dim2):
rtol=1e-7,
)

@parameterized.parameters(1e-2, 1e-6, 1e-9, 1e-12)
def test_newtonschulz_small_eps(self, scale):
"""Orthogonalization depends only on direction, so scaling the input must not change the output.

Regression test for issue #229: a too-large ``eps`` in the internal ``F.normalize`` divides
small-norm inputs by ``eps`` instead of their norm, silently degenerating the output. The
orthogonalized result for ``x`` and ``scale * x`` must match for any ``scale > 0``.
"""
@parameterized.parameters(-20, -40, -60)
def test_normalization_scale_invariant(self, exp2):
x = torch.randn(256, 256, device=self.device, dtype=torch.float32)
x = x / x.norm() # unit Frobenius norm direction
ref = muon_utils.newton_schulz(x, steps=5, coefficient_type="quintic")
out = muon_utils.newton_schulz(scale * x, steps=5, coefficient_type="quintic")
torch.testing.assert_close(
out,
ref,
atol=1e-4,
rtol=1e-5,
msg=lambda m: (
f"newton_schulz not scale-invariant at input scale {scale}: "
f"||out||_F={out.norm().item():.4f} vs ||ref||_F={ref.norm().item():.4f}\n{m}"
),
)
ref = muon_utils.newton_schulz(x, steps=0, eps=0)
out = muon_utils.newton_schulz(2**exp2 * x, steps=0, eps=0, normalize_in_double=True)
assert_equal(ref, out)
Comment thread
skyw marked this conversation as resolved.

def test_preserve_values_with_underflowed_norm_in_fp64(self):
# Checks that newton_schulz(..., normalize_in_double=True) still produces x divided by its
# float64 norm when the input is so small that its default-precision norm evaluates to zero.
# NOTE(review): presumably steps=0 skips the iteration so that only the normalization step is
# exercised -- confirm against the body of newton_schulz.
scale = 1e-30
x = torch.randn(256, 256, device=self.device, dtype=torch.float32) * scale
# Precondition: the norm computed without a dtype override must come out as exactly zero,
# otherwise this test is not exercising the underflow case it is named after.
assert torch.linalg.vector_norm(x) == 0 # should underflow
# Reference norm computed with dtype=torch.double; it must be non-zero for the division below.
norm_ref = torch.linalg.vector_norm(x, dtype=torch.double)
assert norm_ref != 0
out = muon_utils.newton_schulz(x, steps=0, normalize_in_double=True)
# atol=0 leaves only the relative tolerance, so the comparison stays meaningful for any
# magnitude of the compared values.
torch.testing.assert_close(x / norm_ref, out, atol=0, rtol=1e-6)

@parameterized.parameters(
(2, 256, 256),
Expand Down
Loading