From f74b86996856e0304f1191ade1e7c3ad0a35bce2 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Tue, 23 Jun 2026 19:15:46 -0700 Subject: [PATCH 1/4] ai add muown Signed-off-by: Hao Wu --- .../orthogonalized_optimizers/__init__.py | 1 + .../orthogonalized_optimizers/muown.py | 221 ++++++++++++++++++ 2 files changed, 222 insertions(+) create mode 100644 emerging_optimizers/orthogonalized_optimizers/muown.py diff --git a/emerging_optimizers/orthogonalized_optimizers/__init__.py b/emerging_optimizers/orthogonalized_optimizers/__init__.py index 1e381a7d..8c784085 100644 --- a/emerging_optimizers/orthogonalized_optimizers/__init__.py +++ b/emerging_optimizers/orthogonalized_optimizers/__init__.py @@ -16,6 +16,7 @@ from emerging_optimizers.orthogonalized_optimizers.mop import * from emerging_optimizers.orthogonalized_optimizers.muon import * from emerging_optimizers.orthogonalized_optimizers.muon_hyperball import * +from emerging_optimizers.orthogonalized_optimizers.muown import * from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import * from emerging_optimizers.orthogonalized_optimizers.polargrad import * from emerging_optimizers.orthogonalized_optimizers.scion import * diff --git a/emerging_optimizers/orthogonalized_optimizers/muown.py b/emerging_optimizers/orthogonalized_optimizers/muown.py new file mode 100644 index 00000000..28f7ce7e --- /dev/null +++ b/emerging_optimizers/orthogonalized_optimizers/muown.py @@ -0,0 +1,221 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Callable, override + + +if TYPE_CHECKING: + from typing import overload + +import torch +from torch.optim.optimizer import ParamsT + +from emerging_optimizers import registry, utils +from emerging_optimizers.orthogonalized_optimizers.muon import Muon, MuonScaleT +from emerging_optimizers.orthogonalized_optimizers.muon_utils import NSCoeffT +from emerging_optimizers.scalar_optimizers import update_functions +from emerging_optimizers.utils import FP32MatmulPrecT + + +__all__ = ["Muown"] + + +@registry.register_optimizer("muown") +class Muown(Muon): + """Muown: Muon with internal weight normalization (row-norm control). + + Muown (Lion et al., *Muown: Row-Norm Control for Muon Optimization*, arXiv:2605.10797) is a drop-in + replacement for :class:`~emerging_optimizers.orthogonalized_optimizers.muon.Muon` that splits each 2D + weight into a per-row magnitude and a direction, then optimizes them under their natural geometries: + + - **Direction** ``v``: the Muon update (EMA momentum + Newton-Schulz orthogonalization), reusing the + parent's ``scaled_orthogonalize_fn``. + - **Magnitude** ``g`` (one entry per output row): Adam, which realizes the :math:`\\ell_\\infty` duality + map for the diagonal per-neuron gain. + + The reparameterization ``W = Diag(g / ||v||_row) v`` is held implicitly inside the optimizer, so the + forward pass is unchanged. At init ``g``, ``v`` are seeded from ``W`` so Muown starts from the same + point as Muon. The row magnitude is the empirical driver of spectral-norm drift under Muon; making it an + explicit, separately-optimized variable controls that drift without the indiscriminate shrinkage of + plain weight decay. + + A single ``lr`` drives both halves: the direction step carries the ``0.2 * sqrt(max(m, n))`` scaling + (via ``extra_scale_factor``) that matches Adam's update RMS norm, so no separate magnitude lr is needed. + + State per parameter: + - ``step`` + - ``g``: per-row magnitude, shape ``[out, 1]``. + - ``v_norm``: cached row norms ``||v||_row`` of the direction, shape ``[out, 1]``. + - ``momentum_buffer``: EMA momentum of the direction gradient (the Muon momentum on ``v``). + - ``m_g``, ``v_g``: Adam first/second moments of the magnitude gradient. + + Note: + Weight decay is decoupled and applied to the magnitude ``g`` (which is exactly the spectral-norm + driver). Because ``W`` is recomposed from ``g`` after the decay, the invariant ``||W_row|| == g`` + holds without a separate resync. + + Warning: + - This optimizer requires that all parameters passed in are 2D. + - It should not be used for the embedding layer, the final fully connected layer, or any 1-D + parameters; those can all be optimized by a standard method (e.g., AdamW). + - This optimizer is experimental and may change in future versions. + + Args: + params: Iterable of parameters to optimize or dicts defining parameter groups. + lr: Learning rate shared by the direction and magnitude updates. + momentum: EMA momentum for the direction (Muon) update. + betas: Adam ``(beta1, beta2)`` for the magnitude update. + adam_eps: Adam epsilon for the magnitude update. + weight_decay: Decoupled weight decay coefficient, applied to the magnitude ``g``. + fp32_matmul_prec: Precision for the orthogonalization GEMM operations. + coefficient_type: Newton-Schulz coefficient set (see :class:`Muon`). + num_ns_steps: Number of Newton-Schulz iteration steps. + scale_mode: Update scale mode (see :func:`~emerging_optimizers.orthogonalized_optimizers.muon.get_muon_scale_factor`). + extra_scale_factor: Extra scale on the direction update; ``0.2`` matches Adam's update RMS norm. + use_syrk: Whether to use the Triton SYRK kernel for Newton-Schulz. + """ + + def __init__( + self, + params: ParamsT, + lr: float = 3e-4, + momentum: float = 0.95, + weight_decay: float = 0.0, + *, + betas: tuple[float, float] = (0.9, 0.95), + adam_eps: float = 1e-8, + fp32_matmul_prec: FP32MatmulPrecT = "medium", + coefficient_type: NSCoeffT = "quintic", + num_ns_steps: int = 5, + scale_mode: MuonScaleT = "spectral", + extra_scale_factor: float = 1.0, + use_syrk: bool = False, + ) -> None: + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta1: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta2: {betas[1]}") + + self.betas = betas + self.adam_eps = adam_eps + + super().__init__( + params, + lr, + momentum, + weight_decay, + nesterov=False, + weight_decay_method="decoupled", + fp32_matmul_prec=fp32_matmul_prec, + coefficient_type=coefficient_type, + num_ns_steps=num_ns_steps, + scale_mode=scale_mode, + extra_scale_factor=extra_scale_factor, + use_syrk=use_syrk, + ) + + @torch.no_grad() # type: ignore[misc] + @override + def _init_group(self, group: dict, skip_non_grad_params: bool = True) -> None: + for p in group["params"]: + if skip_non_grad_params and p.grad is None: + continue + if p.dim() != 2: + raise TypeError("Muown is only supported for 2D parameters") + state = self.state[p] + if len(state) == 0: + # Seed g, v from the current weight so Muown starts from the same point as Muon. + row_norm = p.norm(dim=1, keepdim=True).to(torch.float32) + state["step"] = 0 + state["g"] = row_norm.clone() + state["v_norm"] = row_norm.clone() + state["momentum_buffer"] = torch.zeros_like(p, dtype=torch.float32) + state["m_g"] = torch.zeros_like(row_norm) + state["v_g"] = torch.zeros_like(row_norm) + + if TYPE_CHECKING: + + @overload + def step(self, closure: None = ...) -> None: ... + + @overload + def step(self, closure: Callable[[], float]) -> float: ... + + @torch.no_grad() # type: ignore[misc] + @override + def step(self, closure: Callable[[], float] | None = None) -> float | None: + """Performs a single optimization step. + + Args: + closure: Unsupported; must be ``None``. + """ + if closure is not None: + raise ValueError("closure is not supported") + + for group in self.param_groups: + self._init_group(group) + + lr = group["lr"] + momentum = group["momentum"] + weight_decay = group["weight_decay"] + for p in group["params"]: + if p.grad is None: + continue # pragma: no cover + + state = self.state[p] + state["step"] += 1 + g = state["g"] + v_norm = state["v_norm"] + + grad = p.grad.to(torch.float32) + weight = p.to(torch.float32) + + # Decompose grad_W into magnitude and direction gradients via the weight-norm Jacobian. + # u = v / ||v||_row is O(1) per element; reconstruct the direction v from the live weight. + u = weight / g + v = u * v_norm + grad_g = (grad * u).sum(dim=1, keepdim=True) + grad_v = (g / v_norm) * (grad - u * grad_g) + + # Direction: Muon update on v (EMA momentum + Newton-Schulz), reusing the parent LMO. + state["momentum_buffer"].lerp_(grad_v, 1 - momentum) + with utils.fp32_matmul_precision(self.fp32_matmul_prec): + direction_update = self.scaled_orthogonalize_fn(state["momentum_buffer"]) + v_new = v.add(direction_update, alpha=-lr) + + # Magnitude: Adam on g (the l-infinity duality map for the diagonal gain), same lr. + magnitude_update = update_functions.calculate_adam_update( + grad_g, + state["m_g"], + state["v_g"], + betas=self.betas, + eps=self.adam_eps, + correct_bias=True, + nesterov=False, + step=state["step"], + ) + g.add_(magnitude_update, alpha=-lr) + + # Decoupled weight decay on the magnitude (the spectral-norm driver). + if weight_decay != 0.0: + g.add_(g, alpha=-weight_decay * lr) + + # Recompose W = g * v_new / ||v_new||_row and refresh the cached direction norm. Recomposing + # from the decayed g keeps the invariant ||W_row|| == g without a separate resync. + v_norm_new = v_new.norm(dim=1, keepdim=True) + p.copy_(g * (v_new / v_norm_new)) + state["v_norm"] = v_norm_new + + return None From ba90c5a4da7c9374cdae25e183bcf5242469dc24 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 24 Jun 2026 11:46:26 -0700 Subject: [PATCH 2/4] add AI written tests Signed-off-by: Hao Wu --- tests/muown_reference.py | 143 +++++++++++++++++++++++++++++++++++ tests/test_muown.py | 156 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 299 insertions(+) create mode 100644 tests/muown_reference.py create mode 100644 tests/test_muown.py diff --git a/tests/muown_reference.py b/tests/muown_reference.py new file mode 100644 index 00000000..8928cda1 --- /dev/null +++ b/tests/muown_reference.py @@ -0,0 +1,143 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Reference Muown implementation, adapted from the authors' code for use as a test oracle: +# https://github.com/kcc-lion/muown/blob/main/optim/muown.py +# Lion et al., "Muown: Row-Norm Control for Muon Optimization", arXiv:2605.10797 (paper: CC BY 4.0). +# The repository did not declare a code license at the time of copying. +"""Reference Muown implementation used as a test oracle.""" + +from typing import Callable + +import torch +from torch import Tensor +from torch.optim.optimizer import Optimizer + + +def _wn_pre_ns(W: Tensor, g: Tensor, v_norm: Tensor, grad_W: Tensor) -> tuple[Tensor, Tensor, Tensor]: + """Reconstruct direction v from (W, g, v_norm) and split grad_W into (grad_g, grad_v).""" + u = W / g + v = u * v_norm + grad_g = (grad_W * u).sum(dim=1, keepdim=True) + grad_v = (g / v_norm) * (grad_W - u * grad_g) + return v, grad_g, grad_v + + +def _wn_recompose(W: Tensor, g: Tensor, v_new: Tensor) -> Tensor: + """Write W = g * v_new / ||v_new||_row in place and return the new row norms.""" + v_norm_new = v_new.norm(dim=1, keepdim=True) + W.copy_(g * (v_new / v_norm_new)) + return v_norm_new + + +class MuownReference(Optimizer): + """Single-process reference Muown with an injected orthogonalization callable.""" + + def __init__( + self, + params, + orthogonalize_fn: Callable[[Tensor], Tensor], + lr: float = 3e-4, + momentum: float = 0.95, + nesterov: bool = False, + betas: tuple[float, float] = (0.9, 0.95), + weight_decay: float = 0.0, + adam_eps: float = 1e-8, + ) -> None: + self._orthogonalize_fn = orthogonalize_fn + defaults = dict( + lr=lr, + momentum=momentum, + nesterov=nesterov, + betas=betas, + weight_decay=weight_decay, + adam_eps=adam_eps, + ) + super().__init__(params, defaults) + + def _init_state_2d(self, p: Tensor, state: dict) -> None: + w_norm = p.data.norm(dim=1, keepdim=True) + state["g"] = w_norm.clone() + state["v_norm"] = w_norm.clone() + state["m_v"] = torch.zeros_like(p.data) + state["m_g"] = torch.zeros_like(w_norm) + state["v_g"] = torch.zeros_like(w_norm) + state["step"] = 0 + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + nesterov = group["nesterov"] + betas = group["betas"] + weight_decay = group["weight_decay"] + adam_eps = group["adam_eps"] + + for p in group["params"]: + if p.grad is None: + continue + + grad = p.grad + state = self.state[p] + + if len(state) == 0: + self._init_state_2d(p, state) + + state["step"] += 1 + step = state["step"] + + g = state["g"] + v_norm = state["v_norm"] + m_v = state["m_v"] + m_g = state["m_g"] + v_g = state["v_g"] + if weight_decay != 0.0: + W_old = p.data.clone() + + # Fused: reconstruct v + compute weight norm gradients + v, grad_g, grad_v = _wn_pre_ns(p.data, g, v_norm, grad) + + # Muon update on v: momentum + orthogonalization + m_v.mul_(momentum).add_(grad_v) + if nesterov: + update = grad_v.add(m_v, alpha=momentum) + else: + update = m_v.clone() + + # Injected orthogonalization (folds in the 0.2 * sqrt(max(m, n)) scaling). + update = self._orthogonalize_fn(update) + v_new = v.add(update, alpha=-lr) + + # Adam update on g (small [out_features, 1] vectors) + beta1, beta2 = betas + m_g.mul_(beta1).add_(grad_g, alpha=1 - beta1) + v_g.mul_(beta2).addcmul_(grad_g, grad_g, value=1 - beta2) + bc1 = 1 - beta1**step + bc2 = 1 - beta2**step + g.addcdiv_(m_g / bc1, (v_g / bc2).sqrt().add_(adam_eps), value=-lr) + + # Fused: recompose W = g * v_new / ||v_new||, writes directly into p.data + state["v_norm"] = _wn_recompose(p.data, g, v_new) + if weight_decay != 0.0: + p.data.add_(W_old, alpha=-lr * weight_decay) + g.copy_(p.data.norm(dim=1, keepdim=True)) + + return loss diff --git a/tests/test_muown.py b/tests/test_muown.py new file mode 100644 index 00000000..e0ea0fa1 --- /dev/null +++ b/tests/test_muown.py @@ -0,0 +1,156 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from absl import flags, logging +from absl.testing import absltest, parameterized +from muown_reference import MuownReference + +from emerging_optimizers.orthogonalized_optimizers.muown import Muown + + +flags.DEFINE_enum("device", "cpu", ["cpu", "cuda"], "Device to run tests on") +flags.DEFINE_integer("seed", None, "Random seed for reproducible tests") +FLAGS = flags.FLAGS + + +def setUpModule() -> None: + if FLAGS.seed is not None: + logging.info("Setting random seed to %d", FLAGS.seed) + torch.manual_seed(FLAGS.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(FLAGS.seed) + # Use the most precise fp32 matmul path so the reference (which calls the optimizer's + # orthogonalization outside the fp32_matmul_precision context) computes bit-comparable results. + torch.set_float32_matmul_precision("highest") + + +class MuownTest(parameterized.TestCase): + def setUp(self): + super().setUp() + self.device = FLAGS.device + + @parameterized.product(shape=[(8, 16), (16, 8), (33, 65)]) + def test_smoke(self, shape): + """A few steps run and keep the weight finite.""" + p = torch.nn.Parameter(torch.randn(shape, device=self.device)) + opt = Muown([p], lr=1e-2, weight_decay=0.01) + for _ in range(3): + p.grad = torch.randn_like(p) + opt.step() + self.assertTrue(torch.isfinite(p).all()) + + def test_raises_on_non_2d(self): + """Muown supports 2D parameters only.""" + for shape in [(8,), (2, 3, 4)]: + p = torch.nn.Parameter(torch.randn(shape, device=self.device)) + p.grad = torch.randn_like(p) + opt = Muown([p], lr=1e-2) + with self.assertRaises(TypeError): + opt.step() + + def test_raises_on_closure(self): + """Closures are not supported.""" + p = torch.nn.Parameter(torch.randn((8, 16), device=self.device)) + p.grad = torch.randn_like(p) + opt = Muown([p], lr=1e-2) + with self.assertRaises(ValueError): + opt.step(lambda: 0.0) + + @parameterized.product(shape=[(8, 16), (16, 8)], weight_decay=[0.0, 0.1]) + def test_row_norm_equals_magnitude_state(self, shape, weight_decay): + """The reparameterization invariant ||W_row|| == g holds after each step.""" + p = torch.nn.Parameter(torch.randn(shape, device=self.device)) + opt = Muown([p], lr=1e-2, weight_decay=weight_decay) + for _ in range(4): + p.grad = torch.randn_like(p) + opt.step() + row_norm = p.detach().norm(dim=1, keepdim=True) + torch.testing.assert_close( + row_norm, + opt.state[p]["g"], + atol=1e-5, + rtol=1e-5, + msg=lambda m: f"Row norms of W must equal the magnitude state g.\n\n{m}", + ) + + def test_weight_decay_shrinks_magnitude(self): + """Decoupled weight decay shrinks the magnitude g relative to the no-decay run.""" + p_wd = torch.nn.Parameter(torch.randn((16, 32), device=self.device)) + p_no = torch.nn.Parameter(p_wd.detach().clone()) + opt_wd = Muown([p_wd], lr=1e-2, weight_decay=0.1) + opt_no = Muown([p_no], lr=1e-2, weight_decay=0.0) + for _ in range(5): + grad = torch.randn_like(p_wd) + p_wd.grad = grad.clone() + p_no.grad = grad.clone() + opt_wd.step() + opt_no.step() + self.assertLess(opt_wd.state[p_wd]["g"].sum().item(), opt_no.state[p_no]["g"].sum().item()) + + @parameterized.product(shape=[(8, 16), (16, 8), (33, 65)], momentum=[0.0, 0.95]) + def test_agrees_with_reference(self, shape, momentum): + """Muown matches the authors' reference implementation (no weight decay). + + The reference uses classic (heavy-ball) momentum while Muown uses EMA momentum; the two differ by + a constant factor that the scale-invariant Newton-Schulz orthogonalization removes, so the updates + agree. Both use the same injected orthogonalization, so only float rounding (lerp vs mul/add, + Newton-Schulz normalization) separates them — hence a tight tolerance rather than bit-identity. + """ + p = torch.nn.Parameter(torch.randn(shape, device=self.device)) + p_ref = torch.nn.Parameter(p.detach().clone()) + + opt = Muown( + [p], + lr=1e-2, + momentum=momentum, + weight_decay=0.0, + extra_scale_factor=0.2, + fp32_matmul_prec="highest", + ) + # Hold orthogonalization identical: feed the reference Muown's own scaled_orthogonalize_fn. + opt_ref = MuownReference( + [p_ref], + orthogonalize_fn=opt.scaled_orthogonalize_fn, + lr=1e-2, + momentum=momentum, + nesterov=False, + weight_decay=0.0, + ) + + for _ in range(5): + grad = torch.randn_like(p) + p.grad = grad.clone() + p_ref.grad = grad.clone() + opt.step() + opt_ref.step() + + torch.testing.assert_close( + p.detach(), + p_ref.detach(), + atol=1e-5, + rtol=1e-4, + msg=lambda m: f"Muown weight diverged from the reference implementation.\n\n{m}", + ) + torch.testing.assert_close( + opt.state[p]["g"], + opt_ref.state[p_ref]["g"], + atol=1e-5, + rtol=1e-4, + msg=lambda m: f"Muown magnitude g diverged from the reference implementation.\n\n{m}", + ) + + +if __name__ == "__main__": + absltest.main() From 22831af6d6b57db062796ecd1c26e6d41fabe5be Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 24 Jun 2026 13:22:32 -0700 Subject: [PATCH 3/4] improve tests Signed-off-by: Hao Wu --- .../orthogonalized_optimizers/muown.py | 41 +--------------- tests/test_muown.py | 49 ++++--------------- 2 files changed, 10 insertions(+), 80 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/muown.py b/emerging_optimizers/orthogonalized_optimizers/muown.py index 28f7ce7e..8247cc84 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muown.py +++ b/emerging_optimizers/orthogonalized_optimizers/muown.py @@ -40,37 +40,6 @@ class Muown(Muon): replacement for :class:`~emerging_optimizers.orthogonalized_optimizers.muon.Muon` that splits each 2D weight into a per-row magnitude and a direction, then optimizes them under their natural geometries: - - **Direction** ``v``: the Muon update (EMA momentum + Newton-Schulz orthogonalization), reusing the - parent's ``scaled_orthogonalize_fn``. - - **Magnitude** ``g`` (one entry per output row): Adam, which realizes the :math:`\\ell_\\infty` duality - map for the diagonal per-neuron gain. - - The reparameterization ``W = Diag(g / ||v||_row) v`` is held implicitly inside the optimizer, so the - forward pass is unchanged. At init ``g``, ``v`` are seeded from ``W`` so Muown starts from the same - point as Muon. The row magnitude is the empirical driver of spectral-norm drift under Muon; making it an - explicit, separately-optimized variable controls that drift without the indiscriminate shrinkage of - plain weight decay. - - A single ``lr`` drives both halves: the direction step carries the ``0.2 * sqrt(max(m, n))`` scaling - (via ``extra_scale_factor``) that matches Adam's update RMS norm, so no separate magnitude lr is needed. - - State per parameter: - - ``step`` - - ``g``: per-row magnitude, shape ``[out, 1]``. - - ``v_norm``: cached row norms ``||v||_row`` of the direction, shape ``[out, 1]``. - - ``momentum_buffer``: EMA momentum of the direction gradient (the Muon momentum on ``v``). - - ``m_g``, ``v_g``: Adam first/second moments of the magnitude gradient. - - Note: - Weight decay is decoupled and applied to the magnitude ``g`` (which is exactly the spectral-norm - driver). Because ``W`` is recomposed from ``g`` after the decay, the invariant ``||W_row|| == g`` - holds without a separate resync. - - Warning: - - This optimizer requires that all parameters passed in are 2D. - - It should not be used for the embedding layer, the final fully connected layer, or any 1-D - parameters; those can all be optimized by a standard method (e.g., AdamW). - - This optimizer is experimental and may change in future versions. Args: params: Iterable of parameters to optimize or dicts defining parameter groups. @@ -182,20 +151,16 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: grad = p.grad.to(torch.float32) weight = p.to(torch.float32) - # Decompose grad_W into magnitude and direction gradients via the weight-norm Jacobian. - # u = v / ||v||_row is O(1) per element; reconstruct the direction v from the live weight. u = weight / g v = u * v_norm grad_g = (grad * u).sum(dim=1, keepdim=True) grad_v = (g / v_norm) * (grad - u * grad_g) - # Direction: Muon update on v (EMA momentum + Newton-Schulz), reusing the parent LMO. state["momentum_buffer"].lerp_(grad_v, 1 - momentum) with utils.fp32_matmul_precision(self.fp32_matmul_prec): direction_update = self.scaled_orthogonalize_fn(state["momentum_buffer"]) v_new = v.add(direction_update, alpha=-lr) - # Magnitude: Adam on g (the l-infinity duality map for the diagonal gain), same lr. magnitude_update = update_functions.calculate_adam_update( grad_g, state["m_g"], @@ -208,12 +173,8 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: ) g.add_(magnitude_update, alpha=-lr) - # Decoupled weight decay on the magnitude (the spectral-norm driver). - if weight_decay != 0.0: - g.add_(g, alpha=-weight_decay * lr) + g.add_(g, alpha=-weight_decay * lr) - # Recompose W = g * v_new / ||v_new||_row and refresh the cached direction norm. Recomposing - # from the decayed g keeps the invariant ||W_row|| == g without a separate resync. v_norm_new = v_new.norm(dim=1, keepdim=True) p.copy_(g * (v_new / v_norm_new)) state["v_norm"] = v_norm_new diff --git a/tests/test_muown.py b/tests/test_muown.py index e0ea0fa1..82453f5d 100644 --- a/tests/test_muown.py +++ b/tests/test_muown.py @@ -31,9 +31,6 @@ def setUpModule() -> None: torch.manual_seed(FLAGS.seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(FLAGS.seed) - # Use the most precise fp32 matmul path so the reference (which calls the optimizer's - # orthogonalization outside the fp32_matmul_precision context) computes bit-comparable results. - torch.set_float32_matmul_precision("highest") class MuownTest(parameterized.TestCase): @@ -43,7 +40,6 @@ def setUp(self): @parameterized.product(shape=[(8, 16), (16, 8), (33, 65)]) def test_smoke(self, shape): - """A few steps run and keep the weight finite.""" p = torch.nn.Parameter(torch.randn(shape, device=self.device)) opt = Muown([p], lr=1e-2, weight_decay=0.01) for _ in range(3): @@ -52,7 +48,6 @@ def test_smoke(self, shape): self.assertTrue(torch.isfinite(p).all()) def test_raises_on_non_2d(self): - """Muown supports 2D parameters only.""" for shape in [(8,), (2, 3, 4)]: p = torch.nn.Parameter(torch.randn(shape, device=self.device)) p.grad = torch.randn_like(p) @@ -61,7 +56,6 @@ def test_raises_on_non_2d(self): opt.step() def test_raises_on_closure(self): - """Closures are not supported.""" p = torch.nn.Parameter(torch.randn((8, 16), device=self.device)) p.grad = torch.randn_like(p) opt = Muown([p], lr=1e-2) @@ -70,7 +64,6 @@ def test_raises_on_closure(self): @parameterized.product(shape=[(8, 16), (16, 8)], weight_decay=[0.0, 0.1]) def test_row_norm_equals_magnitude_state(self, shape, weight_decay): - """The reparameterization invariant ||W_row|| == g holds after each step.""" p = torch.nn.Parameter(torch.randn(shape, device=self.device)) opt = Muown([p], lr=1e-2, weight_decay=weight_decay) for _ in range(4): @@ -80,50 +73,28 @@ def test_row_norm_equals_magnitude_state(self, shape, weight_decay): torch.testing.assert_close( row_norm, opt.state[p]["g"], - atol=1e-5, + atol=0, rtol=1e-5, - msg=lambda m: f"Row norms of W must equal the magnitude state g.\n\n{m}", ) - def test_weight_decay_shrinks_magnitude(self): - """Decoupled weight decay shrinks the magnitude g relative to the no-decay run.""" - p_wd = torch.nn.Parameter(torch.randn((16, 32), device=self.device)) - p_no = torch.nn.Parameter(p_wd.detach().clone()) - opt_wd = Muown([p_wd], lr=1e-2, weight_decay=0.1) - opt_no = Muown([p_no], lr=1e-2, weight_decay=0.0) - for _ in range(5): - grad = torch.randn_like(p_wd) - p_wd.grad = grad.clone() - p_no.grad = grad.clone() - opt_wd.step() - opt_no.step() - self.assertLess(opt_wd.state[p_wd]["g"].sum().item(), opt_no.state[p_no]["g"].sum().item()) - @parameterized.product(shape=[(8, 16), (16, 8), (33, 65)], momentum=[0.0, 0.95]) - def test_agrees_with_reference(self, shape, momentum): - """Muown matches the authors' reference implementation (no weight decay). - - The reference uses classic (heavy-ball) momentum while Muown uses EMA momentum; the two differ by - a constant factor that the scale-invariant Newton-Schulz orthogonalization removes, so the updates - agree. Both use the same injected orthogonalization, so only float rounding (lerp vs mul/add, - Newton-Schulz normalization) separates them — hence a tight tolerance rather than bit-identity. - """ + def test_close_reference(self, shape, momentum): p = torch.nn.Parameter(torch.randn(shape, device=self.device)) p_ref = torch.nn.Parameter(p.detach().clone()) opt = Muown( [p], - lr=1e-2, + lr=0.125, momentum=momentum, weight_decay=0.0, - extra_scale_factor=0.2, + extra_scale_factor=0.25, fp32_matmul_prec="highest", ) # Hold orthogonalization identical: feed the reference Muown's own scaled_orthogonalize_fn. opt_ref = MuownReference( [p_ref], orthogonalize_fn=opt.scaled_orthogonalize_fn, - lr=1e-2, + lr=0.125, momentum=momentum, nesterov=False, weight_decay=0.0, @@ -139,16 +110,14 @@ def test_agrees_with_reference(self, shape, momentum): torch.testing.assert_close( p.detach(), p_ref.detach(), - atol=1e-5, - rtol=1e-4, - msg=lambda m: f"Muown weight diverged from the reference implementation.\n\n{m}", + atol=1e-6, + rtol=1e-5, ) torch.testing.assert_close( opt.state[p]["g"], opt_ref.state[p_ref]["g"], - atol=1e-5, - rtol=1e-4, - msg=lambda m: f"Muown magnitude g diverged from the reference implementation.\n\n{m}", + atol=1e-7, + rtol=1e-5, ) From 837034803c1f74304fbd1866ffd17a57349ebad1 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 24 Jun 2026 16:19:59 -0700 Subject: [PATCH 4/4] fix some minor issues Signed-off-by: Hao Wu --- .../orthogonalized_optimizers/muown.py | 42 ++++++++++++++----- tests/test_muown.py | 7 ++++ 2 files changed, 39 insertions(+), 10 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/muown.py b/emerging_optimizers/orthogonalized_optimizers/muown.py index 8247cc84..e1cd321e 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muown.py +++ b/emerging_optimizers/orthogonalized_optimizers/muown.py @@ -32,6 +32,31 @@ __all__ = ["Muown"] +@torch.compile +def _weight_norm_decompose( + weight: torch.Tensor, + grad: torch.Tensor, + g: torch.Tensor, + v_norm: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + r"""Reconstructs the direction and splits the gradient under the weight-norm reparameterization. + + Args: + weight: The current 2D weight ``W``. + grad: The gradient ``grad_W`` with respect to ``W``. + g: Per-row magnitude, shape ``[rows, 1]``. + v_norm: Cached row norms ``||v||_row`` of the direction, shape ``[rows, 1]``. + + Returns: + ``(v, grad_g, grad_v)``: the reconstructed direction and the magnitude and direction gradients. + """ + u = weight / g + v = u * v_norm + grad_g = (grad * u).sum(dim=1, keepdim=True) + grad_v = (g / v_norm) * (grad - u * grad_g) + return v, grad_g, grad_v + + @registry.register_optimizer("muown") class Muown(Muon): """Muown: Muon with internal weight normalization (row-norm control). @@ -105,8 +130,10 @@ def _init_group(self, group: dict, skip_non_grad_params: bool = True) -> None: raise TypeError("Muown is only supported for 2D parameters") state = self.state[p] if len(state) == 0: - # Seed g, v from the current weight so Muown starts from the same point as Muon. - row_norm = p.norm(dim=1, keepdim=True).to(torch.float32) + # Seed g, v from the current weight so Muown starts from the same point as Muon. Floor the + # row norm so an all-zero weight row does not give g=0, which would make u = weight / g a + # 0/0 NaN on the first step. + row_norm = p.norm(dim=1, keepdim=True).to(torch.float32).clamp_min(1e-12) state["step"] = 0 state["g"] = row_norm.clone() state["v_norm"] = row_norm.clone() @@ -148,13 +175,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: g = state["g"] v_norm = state["v_norm"] - grad = p.grad.to(torch.float32) - weight = p.to(torch.float32) - - u = weight / g - v = u * v_norm - grad_g = (grad * u).sum(dim=1, keepdim=True) - grad_v = (g / v_norm) * (grad - u * grad_g) + v, grad_g, grad_v = _weight_norm_decompose(p.to(torch.float32), p.grad.to(torch.float32), g, v_norm) state["momentum_buffer"].lerp_(grad_v, 1 - momentum) with utils.fp32_matmul_precision(self.fp32_matmul_prec): @@ -173,7 +194,8 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: ) g.add_(magnitude_update, alpha=-lr) - g.add_(g, alpha=-weight_decay * lr) + # Decoupled weight decay on the magnitude (the spectral-norm driver). + self._apply_weight_decay_inplace(g, grad_g, lr, weight_decay) v_norm_new = v_new.norm(dim=1, keepdim=True) p.copy_(g * (v_new / v_norm_new)) diff --git a/tests/test_muown.py b/tests/test_muown.py index 82453f5d..c82c7d2a 100644 --- a/tests/test_muown.py +++ b/tests/test_muown.py @@ -79,6 +79,13 @@ def test_row_norm_equals_magnitude_state(self, shape, weight_decay): @parameterized.product(shape=[(8, 16), (16, 8), (33, 65)], momentum=[0.0, 0.95]) def test_close_reference(self, shape, momentum): + """Muown matches the reference even though their momentum conventions differ. + + Muown uses EMA momentum (``buf = m*buf + (1-m)*grad_v``) while the reference uses heavy-ball + (``buf = m*buf + grad_v``). The two buffers differ only by the constant factor ``(1 - m)``, which + the scale-invariant Newton-Schulz orthogonalization removes, so the direction updates agree (to + Newton-Schulz's eps-normalization tolerance). Both feed the same orthogonalization function. + """ p = torch.nn.Parameter(torch.randn(shape, device=self.device)) p_ref = torch.nn.Parameter(p.detach().clone())