Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions emerging_optimizers/orthogonalized_optimizers/muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from emerging_optimizers import registry, triton_kernels
from emerging_optimizers.mixin import WeightDecayT
from emerging_optimizers.orthogonalized_optimizers import muon_utils
from emerging_optimizers.orthogonalized_optimizers.muon_utils import NSCoeffT
from emerging_optimizers.orthogonalized_optimizers.muon_utils import NSCoeffT, NSDTypeT
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer, _args_doc
from emerging_optimizers.utils import FP32MatmulPrecT

Expand Down Expand Up @@ -69,6 +69,8 @@ class Muon(OrthogonalizedOptimizer):
extra_scale_factor: The additional scale factor to use for the update. Setting it to 0.2 can closely match
the update RMS norm of AdamW as suggested by https://arxiv.org/abs/2502.16982.
use_syrk: Whether to use the Triton kernel for the Newton-Schulz iteration.
ns_dtype: Dtype used for the Newton-Schulz iteration state and intermediates. "auto" preserves the existing
behavior: use bfloat16 when ``fp32_matmul_prec`` is "medium", otherwise use float32.
"""

def __init__(
Expand All @@ -86,19 +88,29 @@ def __init__(
scale_mode: MuonScaleT = "spectral",
extra_scale_factor: float = 1.0,
use_syrk: bool = False,
ns_dtype: NSDTypeT = "auto",
) -> None:
if num_ns_steps < 1:
raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}")
Comment on lines +91 to 94

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Invalid ns_dtype not caught at construction time

_resolve_ns_dtype raises ValueError for an unrecognized ns_dtype string, but it is only called inside scaled_orthogonalize_fn — a closure executed at step time. A call like Muon(..., ns_dtype="fp16") will construct without error and raise only on the first optimizer step. For the same reason num_ns_steps is validated eagerly above, calling muon_utils._resolve_ns_dtype(ns_dtype) once here (discarding the return value) would catch typos at object creation time rather than mid-training.


muon_utils._validate_ns_dtype(ns_dtype)

if use_syrk:
if torch.cuda.is_available():
resolved_ns_dtype = muon_utils._resolve_ns_dtype(ns_dtype, fp32_matmul_prec)
if resolved_ns_dtype != torch.bfloat16:
logging.error(
f"use_syrk=True requires bfloat16 Newton-Schulz dtype, got {resolved_ns_dtype}. "
"Setting use_syrk to False."
)
use_syrk = False
elif torch.cuda.is_available():
sm_version = torch.cuda.get_device_capability()
else:
sm_version = (0, 0)
if not triton_kernels.HAS_TRITON_340: # type: ignore[attr-defined]
if use_syrk and not triton_kernels.HAS_TRITON_340: # type: ignore[attr-defined]
logging.error("Triton 3.4.0 or higher is required for use_syrk to be True.")
use_syrk = False
elif sm_version not in ((8, 0), (9, 0), (10, 0), (10, 3)):
elif use_syrk and sm_version not in ((8, 0), (9, 0), (10, 0), (10, 3)):
logging.error(
f"Correctness of Triton kernel on SM {sm_version} cannot be guaranteed. Setting use_syrk to False."
)
Expand All @@ -114,6 +126,7 @@ def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
steps=num_ns_steps,
coefficient_type=coefficient_type,
use_syrk=use_syrk,
ns_dtype=ns_dtype,
)
scale_factor = get_muon_scale_factor(grad.size(-2), grad.size(-1), mode=scale_mode)
return orth_grad * scale_factor * extra_scale_factor
Expand Down
68 changes: 60 additions & 8 deletions emerging_optimizers/orthogonalized_optimizers/muon_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,20 @@
from emerging_optimizers import triton_kernels


__all__ = ["newton_schulz", "newton_schulz_tp", "NSCoeffT", "get_coefficient_iterator", "CoeffIterMode"]
__all__ = [
"newton_schulz",
"newton_schulz_tp",
"NSCoeffT",
"NSDTypeT",
"get_coefficient_iterator",
"CoeffIterMode",
]

CoeffIterMode = Literal["cycle", "repeat_last"]

NSCoeffT = Literal["simple", "quintic", "polar_express", "cans", "aol", "deepseekv4", "custom", "cubic5"]
NSDTypeT = Literal["auto", "float32", "bfloat16", "float16"]
_VALID_NS_DTYPES = ("auto", "float32", "bfloat16", "float16")

_COEFFICIENT_SETS = {
# Values are rounded to closest representable in single precision.
Expand Down Expand Up @@ -136,6 +145,30 @@ def distributed_normalize_p2(x: torch.Tensor, eps: float, group: torch.distribut
return x / torch.sqrt(x_sq_sum).clamp_min(eps)


def _validate_ns_dtype(ns_dtype: NSDTypeT) -> None:
    """Reject Newton-Schulz dtype selectors that are not recognized.

    Args:
        ns_dtype: Selector to check against ``_VALID_NS_DTYPES``.

    Raises:
        ValueError: If ``ns_dtype`` is not one of the accepted selectors.
    """
    # Guard clause: accepted selectors return immediately, anything else is an error.
    if ns_dtype in _VALID_NS_DTYPES:
        return
    raise ValueError(f"Invalid Newton-Schulz dtype: {ns_dtype}")


def _resolve_ns_dtype(ns_dtype: NSDTypeT, fp32_matmul_prec: str | None = None) -> torch.dtype:
    """Map a Newton-Schulz dtype selector to a concrete ``torch.dtype``.

    Args:
        ns_dtype: Dtype selector. ``"auto"`` derives the dtype from the FP32 matmul precision;
            the other selectors name the dtype directly.
        fp32_matmul_prec: Precision string to consult for ``"auto"``. When ``None``, the value
            reported by ``torch.get_float32_matmul_precision()`` is used instead.

    Returns:
        ``torch.bfloat16`` for ``"auto"`` under "medium" precision, ``torch.float32`` for ``"auto"``
        under any other precision, otherwise the dtype named by ``ns_dtype``.

    Raises:
        ValueError: Propagated from ``_validate_ns_dtype`` when ``ns_dtype`` is not recognized.
    """
    _validate_ns_dtype(ns_dtype)

    if ns_dtype == "auto":
        # An explicitly supplied precision takes priority over the process-wide setting.
        precision = fp32_matmul_prec if fp32_matmul_prec is not None else torch.get_float32_matmul_precision()
        return torch.bfloat16 if precision == "medium" else torch.float32

    explicit_dtypes = (
        ("float32", torch.float32),
        ("bfloat16", torch.bfloat16),
        ("float16", torch.float16),
    )
    for selector, dtype in explicit_dtypes:
        if ns_dtype == selector:
            return dtype

    # Validation above already rejected every other value.
    raise AssertionError("unreachable")


def newton_schulz(
x: torch.Tensor,
steps: int,
Expand All @@ -145,6 +178,7 @@ def newton_schulz(
transpose: bool | None = None,
tp_group: torch.distributed.ProcessGroup | None = None,
use_syrk: bool = False,
ns_dtype: NSDTypeT = "auto",
) -> torch.Tensor:
"""Use Newton-Schulz iteration to compute the zeroth power / orthogonalization of x.

Expand Down Expand Up @@ -177,6 +211,8 @@ def newton_schulz(
If None, will be determined based on the size of the tensor.
tp_group: The process group for communication if input is distributed.
use_syrk: Whether to use the Triton kernel for the Newton-Schulz iteration.
ns_dtype: Dtype used for the Newton-Schulz iteration state and intermediates. "auto" preserves
the existing behavior: use bfloat16 when FP32 matmul precision is "medium", otherwise use float32.

Returns:
The orthogonalization of x.
Expand Down Expand Up @@ -218,19 +254,32 @@ def newton_schulz(
coeff_iter = get_coefficient_iterator(steps, coefficient_sets, mode=iter_mode)

ns_step_fn = newton_schulz_step if X.ndim == 2 else batched_newton_schulz_step
resolved_ns_dtype = _resolve_ns_dtype(ns_dtype)
# Perform the NS iterations
if torch.get_float32_matmul_precision() == "medium":
# PyTorch doesn't really have FP32 I/O BF16 compute kernels for precision "medium"
# We explicitly convert to BF16 and back to FP32.
if resolved_ns_dtype != torch.float32:
# PyTorch doesn't really have FP32 I/O lower-precision compute kernels for precision "medium".
# We explicitly convert the NS state to the selected dtype and back to FP32.
# NOTE: There is a small difference to calling FP32 I/O BF16 compute kernels because the final result
# is converted to BF16 before converting back to FP32. The rest should be the same as long as epilogue
# is always in FP32.
if use_syrk:
if use_syrk and resolved_ns_dtype == torch.bfloat16:
if X.ndim > 2:
raise TypeError("use_syrk does not support N-d input.")
ns_step_fn = newton_schulz_step_tsyrk
X = X.to(torch.bfloat16)
logging.log_first_n(logging.INFO, "Using BF16 I/O kernels for Newton-Schulz iteration.", 1)
elif use_syrk:
logging.log_first_n(
logging.WARNING,
"use_syrk=True requires bfloat16 Newton-Schulz dtype; falling back to addmm/baddbmm.",
1,
)
X = X.to(resolved_ns_dtype)
logging.log_first_n(logging.INFO, "Using %s I/O kernels for Newton-Schulz iteration.", 1, resolved_ns_dtype)
elif use_syrk:
logging.log_first_n(
logging.WARNING,
"use_syrk=True requires bfloat16 Newton-Schulz dtype; falling back to addmm/baddbmm.",
1,
)

for a, b, c in coeff_iter:
X = ns_step_fn(X, a, b, c, tp_group=tp_group)
Expand All @@ -251,6 +300,7 @@ def newton_schulz_tp(
tp_group: torch.distributed.ProcessGroup,
partition_dim: int | None = None,
tp_mode: Literal["duplicated", "distributed"] = "duplicated",
ns_dtype: NSDTypeT = "auto",
) -> torch.Tensor:
"""Tensor Parallel Newton-Schulz iteration.

Expand Down Expand Up @@ -278,14 +328,16 @@ def newton_schulz_tp(
partition_dim: The dimension to partition the tensor.
tp_group: The process group for communication if input is distributed.
tp_mode: The mode to use for the Newton-Schulz iteration.
ns_dtype: Dtype used for the Newton-Schulz iteration state and intermediates.
"""
if partition_dim is None:
# Fallback path for non TP params.
return newton_schulz(x, steps, coefficient_type)
return newton_schulz(x, steps, coefficient_type, ns_dtype=ns_dtype)

kwargs: Any = {
"steps": steps,
"coefficient_type": coefficient_type,
"ns_dtype": ns_dtype,
}

if tp_mode == "duplicated":
Expand Down
31 changes: 31 additions & 0 deletions tests/test_muon_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,37 @@ def test_newtonschulz_custom_coeff_close_to_reference(self, dim1, dim2):
rtol=1e-6,
)

@parameterized.parameters("float32", "bfloat16", "float16")
def test_newton_schulz_explicit_ns_dtype_returns_fp32(self, ns_dtype):
    """Check that an explicit ``ns_dtype`` still yields a finite float32 result."""
    grad = torch.randn(17, 9, device=self.device, dtype=torch.float32)
    result = muon_utils.newton_schulz(grad, steps=1, coefficient_type="simple", ns_dtype=ns_dtype)

    # The output dtype is expected to be float32 whichever dtype was requested for the iteration.
    self.assertEqual(result.dtype, torch.float32)
    all_finite = torch.isfinite(result).all().item()
    self.assertTrue(all_finite)

def test_newton_schulz_auto_ns_dtype_preserves_medium_precision_behavior(self):
    """Check that ``"auto"`` matches ``"bfloat16"`` exactly inside the "medium" precision context."""
    grad = torch.randn(17, 9, device=self.device, dtype=torch.float32)
    shared_kwargs = dict(steps=1, coefficient_type="simple")

    with utils.fp32_matmul_precision("medium"):
        auto_result = muon_utils.newton_schulz(grad, ns_dtype="auto", **shared_kwargs)
        bf16_result = muon_utils.newton_schulz(grad, ns_dtype="bfloat16", **shared_kwargs)

    # Zero tolerances: the two results must be identical, not merely close.
    torch.testing.assert_close(auto_result, bf16_result, atol=0, rtol=0)

def test_newton_schulz_auto_ns_dtype_preserves_highest_precision_behavior(self):
    """Check that ``"auto"`` matches ``"float32"`` exactly inside the "highest" precision context."""
    grad = torch.randn(17, 9, device=self.device, dtype=torch.float32)
    shared_kwargs = dict(steps=1, coefficient_type="simple")

    with utils.fp32_matmul_precision("highest"):
        auto_result = muon_utils.newton_schulz(grad, ns_dtype="auto", **shared_kwargs)
        fp32_result = muon_utils.newton_schulz(grad, ns_dtype="float32", **shared_kwargs)

    # Zero tolerances: the two results must be identical, not merely close.
    torch.testing.assert_close(auto_result, fp32_result, atol=0, rtol=0)

def test_newton_schulz_invalid_ns_dtype_raises_value_error(self) -> None:
    """Check that ``newton_schulz`` raises ``ValueError`` for an unrecognized ``ns_dtype``."""
    grad = torch.randn(5, 7, device=self.device, dtype=torch.float32)
    expected_message = "Invalid Newton-Schulz dtype"

    with self.assertRaisesRegex(ValueError, expected_message):
        muon_utils.newton_schulz(grad, steps=5, coefficient_type="quintic", ns_dtype="invalid")

@parameterized.product(
size=[(512, 512), (512, 256), (256, 512)],
coefficient_type=["polar_express", "deepseekv4"],
Expand Down
35 changes: 34 additions & 1 deletion tests/test_orthogonalized_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from absl import flags, logging
from absl.testing import absltest, parameterized

from emerging_optimizers.orthogonalized_optimizers import mop, muon, muon_hyperball, polargrad, scion
from emerging_optimizers.orthogonalized_optimizers import mop, muon, muon_hyperball, muon_utils, polargrad, scion
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer


Expand Down Expand Up @@ -272,6 +272,34 @@ def test_use_syrk_match_without_syrk(self) -> None:
ref_param.data,
)

def test_ns_dtype_passed_to_muon_orthogonalization(self) -> None:
    """Check that ``Muon`` forwards ``ns_dtype`` to the Newton-Schulz orthogonalization.

    One optimizer step from a zero parameter is compared against calling
    ``muon_utils.newton_schulz`` directly with the same ``ns_dtype`` and applying the
    spectral scale factor by hand.
    """
    rows, cols = shape = (17, 9)
    param = nn.Parameter(torch.zeros(shape, dtype=torch.float32, device=self.device))
    param.grad = torch.randint(-5, 5, shape, dtype=torch.float32, device=self.device)
    # Snapshot the gradient before stepping so the reference uses the same input.
    grad_snapshot = param.grad.clone()

    optimizer = muon.Muon(
        [param],
        lr=1.0,
        momentum=0.0,
        weight_decay=0.0,
        coefficient_type="simple",
        num_ns_steps=1,
        fp32_matmul_prec="highest",
        ns_dtype="float16",
    )
    optimizer.step()

    orthogonalized = muon_utils.newton_schulz(
        grad_snapshot,
        steps=1,
        coefficient_type="simple",
        ns_dtype="float16",
    )
    scale = muon.get_muon_scale_factor(rows, cols, mode="spectral")
    expected_update = orthogonalized * scale

    # The parameter started at zero with lr=1.0, so it should equal the negated update exactly.
    torch.testing.assert_close(param.data, -expected_update, atol=0, rtol=0)

def test_use_independent_wd(self) -> None:
"""Test that use_independent_wd properly decouples weight decay from learning rate."""
shape = (32, 32)
Expand Down Expand Up @@ -306,6 +334,11 @@ def test_zero_num_ns_steps_raises_value_error(self) -> None:
with self.assertRaisesRegex(ValueError, "num_ns_steps must be at least 1"):
muon.Muon([test_param], num_ns_steps=0)

def test_invalid_ns_dtype_raises_value_error(self) -> None:
    """Check that constructing ``Muon`` with an unrecognized ``ns_dtype`` raises ``ValueError``."""
    param = nn.Parameter(torch.randn(5, 7, dtype=torch.float32, device=self.device))
    expected_message = "Invalid Newton-Schulz dtype"

    # The error must surface at construction; no optimizer step is taken here.
    with self.assertRaisesRegex(ValueError, expected_message):
        muon.Muon([param], ns_dtype="invalid")

def test_invalid_scale_mode_raises_value_error(self) -> None:
"""Test that get_muon_scale_factor raises ValueError for invalid mode."""
with self.assertRaisesRegex(ValueError, "Invalid mode.*invalid_mode"):
Expand Down