diff --git a/emerging_optimizers/orthogonalized_optimizers/muon.py b/emerging_optimizers/orthogonalized_optimizers/muon.py index e962239..80e1fec 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon.py @@ -22,7 +22,7 @@ from emerging_optimizers import registry, triton_kernels from emerging_optimizers.mixin import WeightDecayT from emerging_optimizers.orthogonalized_optimizers import muon_utils -from emerging_optimizers.orthogonalized_optimizers.muon_utils import NSCoeffT +from emerging_optimizers.orthogonalized_optimizers.muon_utils import NSCoeffT, NSDTypeT from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer, _args_doc from emerging_optimizers.utils import FP32MatmulPrecT @@ -69,6 +69,8 @@ class Muon(OrthogonalizedOptimizer): extra_scale_factor: The additional scale factor to use for the update. Setting it to 0.2 can closely match the update RMS norm of AdamW as suggested by https://arxiv.org/abs/2502.16982. use_syrk: Whether to use the Triton kernel for the Newton-Schulz iteration. + ns_dtype: Dtype used for the Newton-Schulz iteration state and intermediates. "auto" preserves the existing + behavior: use bfloat16 when ``fp32_matmul_prec`` is "medium", otherwise use float32. """ def __init__( @@ -86,19 +88,29 @@ def __init__( scale_mode: MuonScaleT = "spectral", extra_scale_factor: float = 1.0, use_syrk: bool = False, + ns_dtype: NSDTypeT = "auto", ) -> None: if num_ns_steps < 1: raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}") + muon_utils._validate_ns_dtype(ns_dtype) + if use_syrk: - if torch.cuda.is_available(): + resolved_ns_dtype = muon_utils._resolve_ns_dtype(ns_dtype, fp32_matmul_prec) + if resolved_ns_dtype != torch.bfloat16: + logging.error( + f"use_syrk=True requires bfloat16 Newton-Schulz dtype, got {resolved_ns_dtype}. " + "Setting use_syrk to False." + ) + use_syrk = False + elif torch.cuda.is_available(): sm_version = torch.cuda.get_device_capability() else: sm_version = (0, 0) - if not triton_kernels.HAS_TRITON_340: # type: ignore[attr-defined] + if use_syrk and not triton_kernels.HAS_TRITON_340: # type: ignore[attr-defined] logging.error("Triton 3.4.0 or higher is required for use_syrk to be True.") use_syrk = False - elif sm_version not in ((8, 0), (9, 0), (10, 0), (10, 3)): + elif use_syrk and sm_version not in ((8, 0), (9, 0), (10, 0), (10, 3)): logging.error( f"Correctness of Triton kernel on SM {sm_version} cannot be guaranteed. Setting use_syrk to False." ) @@ -114,6 +126,7 @@ def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor: steps=num_ns_steps, coefficient_type=coefficient_type, use_syrk=use_syrk, + ns_dtype=ns_dtype, ) scale_factor = get_muon_scale_factor(grad.size(-2), grad.size(-1), mode=scale_mode) return orth_grad * scale_factor * extra_scale_factor diff --git a/emerging_optimizers/orthogonalized_optimizers/muon_utils.py b/emerging_optimizers/orthogonalized_optimizers/muon_utils.py index b29e02b..d227cb4 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon_utils.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon_utils.py @@ -21,11 +21,20 @@ from emerging_optimizers import triton_kernels -__all__ = ["newton_schulz", "newton_schulz_tp", "NSCoeffT", "get_coefficient_iterator", "CoeffIterMode"] +__all__ = [ + "newton_schulz", + "newton_schulz_tp", + "NSCoeffT", + "NSDTypeT", + "get_coefficient_iterator", + "CoeffIterMode", +] CoeffIterMode = Literal["cycle", "repeat_last"] NSCoeffT = Literal["simple", "quintic", "polar_express", "cans", "aol", "deepseekv4", "custom", "cubic5"] +NSDTypeT = Literal["auto", "float32", "bfloat16", "float16"] +_VALID_NS_DTYPES = ("auto", "float32", "bfloat16", "float16") _COEFFICIENT_SETS = { # Values are rounded to closest representable in single precision. @@ -136,6 +145,30 @@ def distributed_normalize_p2(x: torch.Tensor, eps: float, group: torch.distribut return x / torch.sqrt(x_sq_sum).clamp_min(eps) +def _validate_ns_dtype(ns_dtype: NSDTypeT) -> None: + """Validate the dtype used inside Newton-Schulz iterations.""" + if ns_dtype not in _VALID_NS_DTYPES: + raise ValueError(f"Invalid Newton-Schulz dtype: {ns_dtype}") + + +def _resolve_ns_dtype(ns_dtype: NSDTypeT, fp32_matmul_prec: str | None = None) -> torch.dtype: + """Resolve the dtype used inside Newton-Schulz iterations.""" + _validate_ns_dtype(ns_dtype) + if ns_dtype == "auto": + if fp32_matmul_prec is not None: + return torch.bfloat16 if fp32_matmul_prec == "medium" else torch.float32 + if torch.get_float32_matmul_precision() == "medium": + return torch.bfloat16 + return torch.float32 + if ns_dtype == "float32": + return torch.float32 + if ns_dtype == "bfloat16": + return torch.bfloat16 + if ns_dtype == "float16": + return torch.float16 + raise AssertionError("unreachable") + + def newton_schulz( x: torch.Tensor, steps: int, @@ -145,6 +178,7 @@ def newton_schulz( transpose: bool | None = None, tp_group: torch.distributed.ProcessGroup | None = None, use_syrk: bool = False, + ns_dtype: NSDTypeT = "auto", ) -> torch.Tensor: """Use Newton-Schulz iteration to compute the zeroth power / orthogonalization of x. @@ -177,6 +211,8 @@ def newton_schulz( If None, will be determined based on the size of the tensor. tp_group: The process group for communication if input is distributed. use_syrk: Whether to use the Triton kernel for the Newton-Schulz iteration. + ns_dtype: Dtype used for the Newton-Schulz iteration state and intermediates. "auto" preserves + the existing behavior: use bfloat16 when FP32 matmul precision is "medium", otherwise use float32. Returns: The orthogonalization of x. @@ -218,19 +254,32 @@ def newton_schulz( coeff_iter = get_coefficient_iterator(steps, coefficient_sets, mode=iter_mode) ns_step_fn = newton_schulz_step if X.ndim == 2 else batched_newton_schulz_step + resolved_ns_dtype = _resolve_ns_dtype(ns_dtype) # Perform the NS iterations - if torch.get_float32_matmul_precision() == "medium": - # PyTorch doesn't really have FP32 I/O BF16 compute kernels for precision "medium" - # We explicitly convert to BF16 and back to FP32. + if resolved_ns_dtype != torch.float32: + # PyTorch doesn't really have FP32 I/O lower-precision compute kernels for precision "medium". + # We explicitly convert the NS state to the selected dtype and back to FP32. # NOTE: There is a small difference to calling FP32 I/O BF16 compute kernels because the final result # is converted to BF16 before converting back to FP32. The rest should be the same as long as epilogue # is always in FP32. - if use_syrk: + if use_syrk and resolved_ns_dtype == torch.bfloat16: if X.ndim > 2: raise TypeError("use_syrk does not support N-d input.") ns_step_fn = newton_schulz_step_tsyrk - X = X.to(torch.bfloat16) - logging.log_first_n(logging.INFO, "Using BF16 I/O kernels for Newton-Schulz iteration.", 1) + elif use_syrk: + logging.log_first_n( + logging.WARNING, + "use_syrk=True requires bfloat16 Newton-Schulz dtype; falling back to addmm/baddbmm.", + 1, + ) + X = X.to(resolved_ns_dtype) + logging.log_first_n(logging.INFO, "Using %s I/O kernels for Newton-Schulz iteration.", 1, resolved_ns_dtype) + elif use_syrk: + logging.log_first_n( + logging.WARNING, + "use_syrk=True requires bfloat16 Newton-Schulz dtype; falling back to addmm/baddbmm.", + 1, + ) for a, b, c in coeff_iter: X = ns_step_fn(X, a, b, c, tp_group=tp_group) @@ -251,6 +300,7 @@ def newton_schulz_tp( tp_group: torch.distributed.ProcessGroup, partition_dim: int | None = None, tp_mode: Literal["duplicated", "distributed"] = "duplicated", + ns_dtype: NSDTypeT = "auto", ) -> torch.Tensor: """Tensor Parallel Newton-Schulz iteration. @@ -278,14 +328,16 @@ def newton_schulz_tp( partition_dim: The dimension to partition the tensor. tp_group: The process group for communication if input is distributed. tp_mode: The mode to use for the Newton-Schulz iteration. + ns_dtype: Dtype used for the Newton-Schulz iteration state and intermediates. """ if partition_dim is None: # Fallback path for non TP params. - return newton_schulz(x, steps, coefficient_type) + return newton_schulz(x, steps, coefficient_type, ns_dtype=ns_dtype) kwargs: Any = { "steps": steps, "coefficient_type": coefficient_type, + "ns_dtype": ns_dtype, } if tp_mode == "duplicated": diff --git a/tests/test_muon_utils.py b/tests/test_muon_utils.py index ac58f55..bd2868a 100644 --- a/tests/test_muon_utils.py +++ b/tests/test_muon_utils.py @@ -199,6 +199,37 @@ def test_newtonschulz_custom_coeff_close_to_reference(self, dim1, dim2): rtol=1e-6, ) + @parameterized.parameters("float32", "bfloat16", "float16") + def test_newton_schulz_explicit_ns_dtype_returns_fp32(self, ns_dtype): + x = torch.randn(17, 9, device=self.device, dtype=torch.float32) + out = muon_utils.newton_schulz(x, steps=1, coefficient_type="simple", ns_dtype=ns_dtype) + + self.assertEqual(out.dtype, torch.float32) + self.assertTrue(torch.isfinite(out).all().item()) + + def test_newton_schulz_auto_ns_dtype_preserves_medium_precision_behavior(self): + x = torch.randn(17, 9, device=self.device, dtype=torch.float32) + + with utils.fp32_matmul_precision("medium"): + out_auto = muon_utils.newton_schulz(x, steps=1, coefficient_type="simple", ns_dtype="auto") + out_bf16 = muon_utils.newton_schulz(x, steps=1, coefficient_type="simple", ns_dtype="bfloat16") + + torch.testing.assert_close(out_auto, out_bf16, atol=0, rtol=0) + + def test_newton_schulz_auto_ns_dtype_preserves_highest_precision_behavior(self): + x = torch.randn(17, 9, device=self.device, dtype=torch.float32) + + with utils.fp32_matmul_precision("highest"): + out_auto = muon_utils.newton_schulz(x, steps=1, coefficient_type="simple", ns_dtype="auto") + out_fp32 = muon_utils.newton_schulz(x, steps=1, coefficient_type="simple", ns_dtype="float32") + + torch.testing.assert_close(out_auto, out_fp32, atol=0, rtol=0) + + def test_newton_schulz_invalid_ns_dtype_raises_value_error(self) -> None: + x = torch.randn(5, 7, device=self.device, dtype=torch.float32) + with self.assertRaisesRegex(ValueError, "Invalid Newton-Schulz dtype"): + muon_utils.newton_schulz(x, steps=5, coefficient_type="quintic", ns_dtype="invalid") + @parameterized.product( size=[(512, 512), (512, 256), (256, 512)], coefficient_type=["polar_express", "deepseekv4"], diff --git a/tests/test_orthogonalized_optimizer.py b/tests/test_orthogonalized_optimizer.py index 252252a..1d1674a 100644 --- a/tests/test_orthogonalized_optimizer.py +++ b/tests/test_orthogonalized_optimizer.py @@ -18,7 +18,7 @@ from absl import flags, logging from absl.testing import absltest, parameterized -from emerging_optimizers.orthogonalized_optimizers import mop, muon, muon_hyperball, polargrad, scion +from emerging_optimizers.orthogonalized_optimizers import mop, muon, muon_hyperball, muon_utils, polargrad, scion from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer @@ -272,6 +272,34 @@ def test_use_syrk_match_without_syrk(self) -> None: ref_param.data, ) + def test_ns_dtype_passed_to_muon_orthogonalization(self) -> None: + shape = (17, 9) + test_param = nn.Parameter(torch.zeros(shape, dtype=torch.float32, device=self.device)) + test_param.grad = torch.randint(-5, 5, shape, dtype=torch.float32, device=self.device) + grad = test_param.grad.clone() + + muon_opt = muon.Muon( + [test_param], + lr=1.0, + momentum=0.0, + weight_decay=0.0, + coefficient_type="simple", + num_ns_steps=1, + fp32_matmul_prec="highest", + ns_dtype="float16", + ) + muon_opt.step() + + expected_update = muon_utils.newton_schulz( + grad, + steps=1, + coefficient_type="simple", + ns_dtype="float16", + ) + expected_update *= muon.get_muon_scale_factor(shape[0], shape[1], mode="spectral") + + torch.testing.assert_close(test_param.data, -expected_update, atol=0, rtol=0) + def test_use_independent_wd(self) -> None: """Test that use_independent_wd properly decouples weight decay from learning rate.""" shape = (32, 32) @@ -306,6 +334,11 @@ def test_zero_num_ns_steps_raises_value_error(self) -> None: with self.assertRaisesRegex(ValueError, "num_ns_steps must be at least 1"): muon.Muon([test_param], num_ns_steps=0) + def test_invalid_ns_dtype_raises_value_error(self) -> None: + test_param = nn.Parameter(torch.randn(5, 7, dtype=torch.float32, device=self.device)) + with self.assertRaisesRegex(ValueError, "Invalid Newton-Schulz dtype"): + muon.Muon([test_param], ns_dtype="invalid") + def test_invalid_scale_mode_raises_value_error(self) -> None: """Test that get_muon_scale_factor raises ValueError for invalid mode.""" with self.assertRaisesRegex(ValueError, "Invalid mode.*invalid_mode"):