diff --git a/emerging_optimizers/orthogonalized_optimizers/__init__.py b/emerging_optimizers/orthogonalized_optimizers/__init__.py index 1e381a7..8c78408 100644 --- a/emerging_optimizers/orthogonalized_optimizers/__init__.py +++ b/emerging_optimizers/orthogonalized_optimizers/__init__.py @@ -16,6 +16,7 @@ from emerging_optimizers.orthogonalized_optimizers.mop import * from emerging_optimizers.orthogonalized_optimizers.muon import * from emerging_optimizers.orthogonalized_optimizers.muon_hyperball import * +from emerging_optimizers.orthogonalized_optimizers.muown import * from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import * from emerging_optimizers.orthogonalized_optimizers.polargrad import * from emerging_optimizers.orthogonalized_optimizers.scion import * diff --git a/emerging_optimizers/orthogonalized_optimizers/muown.py b/emerging_optimizers/orthogonalized_optimizers/muown.py new file mode 100644 index 0000000..e1cd321 --- /dev/null +++ b/emerging_optimizers/orthogonalized_optimizers/muown.py @@ -0,0 +1,204 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from typing import TYPE_CHECKING, Callable, override


if TYPE_CHECKING:
    from typing import overload

import torch
from torch.optim.optimizer import ParamsT

from emerging_optimizers import registry, utils
from emerging_optimizers.orthogonalized_optimizers.muon import Muon, MuonScaleT
from emerging_optimizers.orthogonalized_optimizers.muon_utils import NSCoeffT
from emerging_optimizers.scalar_optimizers import update_functions
from emerging_optimizers.utils import FP32MatmulPrecT


__all__ = ["Muown"]


# Floor applied to every row norm that is later used as a divisor (``g`` at seeding time and
# ``v_norm`` at seeding time and after each step). Without it an all-zero row yields 0/0 = NaN.
_ROW_NORM_FLOOR = 1e-12


@torch.compile
def _weight_norm_decompose(
    weight: torch.Tensor,
    grad: torch.Tensor,
    g: torch.Tensor,
    v_norm: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""Reconstructs the direction and splits the gradient under the weight-norm reparameterization.

    With ``u = weight / g`` the function computes

    * ``v      = u * v_norm``
    * ``grad_g = sum_row(grad * u)``
    * ``grad_v = (g / v_norm) * (grad - u * grad_g)``

    Args:
        weight: The current 2D weight ``W``.
        grad: The gradient ``grad_W`` with respect to ``W``.
        g: Per-row magnitude, shape ``[rows, 1]``. Must be non-zero (it is used as a divisor).
        v_norm: Cached row norms ``||v||_row`` of the direction, shape ``[rows, 1]``. Must be non-zero
            (it is used as a divisor).

    Returns:
        ``(v, grad_g, grad_v)``: the reconstructed direction and the magnitude and direction gradients.
    """
    u = weight / g
    v = u * v_norm
    grad_g = (grad * u).sum(dim=1, keepdim=True)
    grad_v = (g / v_norm) * (grad - u * grad_g)
    return v, grad_g, grad_v


@registry.register_optimizer("muown")
class Muown(Muon):
    """Muown: Muon with internal weight normalization (row-norm control).

    Muown (Lion et al., *Muown: Row-Norm Control for Muon Optimization*, arXiv:2605.10797) is a drop-in
    replacement for :class:`~emerging_optimizers.orthogonalized_optimizers.muon.Muon` that splits each 2D
    weight into a per-row magnitude and a direction, then optimizes them under their natural geometries:

    * the direction ``v`` takes a Muon step: EMA momentum on the direction gradient followed by the
      scaled orthogonalization inherited from :class:`Muon`;
    * the magnitude ``g`` takes a bias-corrected Adam step followed by decoupled weight decay.

    After both updates the weight is recomposed in place as ``W = g * v / ||v||_row``.

    Note:
        ``betas`` and ``adam_eps`` are stored on the optimizer instance, not in the parameter groups,
        so they are shared by all groups.

    Args:
        params: Iterable of parameters to optimize or dicts defining parameter groups.
        lr: Learning rate shared by the direction and magnitude updates.
        momentum: EMA momentum for the direction (Muon) update.
        weight_decay: Decoupled weight decay coefficient, applied to the magnitude ``g``.
        betas: Adam ``(beta1, beta2)`` for the magnitude update.
        adam_eps: Adam epsilon for the magnitude update.
        fp32_matmul_prec: Precision for the orthogonalization GEMM operations.
        coefficient_type: Newton-Schulz coefficient set (see :class:`Muon`).
        num_ns_steps: Number of Newton-Schulz iteration steps.
        scale_mode: Update scale mode (see :func:`~emerging_optimizers.orthogonalized_optimizers.muon.get_muon_scale_factor`).
        extra_scale_factor: Extra scale on the direction update; ``0.2`` matches Adam's update RMS norm.
        use_syrk: Whether to use the Triton SYRK kernel for Newton-Schulz.

    Raises:
        ValueError: If ``betas[0]`` or ``betas[1]`` is outside ``[0, 1)``.
    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 3e-4,
        momentum: float = 0.95,
        weight_decay: float = 0.0,
        *,
        betas: tuple[float, float] = (0.9, 0.95),
        adam_eps: float = 1e-8,
        fp32_matmul_prec: FP32MatmulPrecT = "medium",
        coefficient_type: NSCoeffT = "quintic",
        num_ns_steps: int = 5,
        scale_mode: MuonScaleT = "spectral",
        extra_scale_factor: float = 1.0,
        use_syrk: bool = False,
    ) -> None:
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta1: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta2: {betas[1]}")

        # Adam hyper-parameters of the magnitude update; kept on the instance (shared by all groups).
        self.betas = betas
        self.adam_eps = adam_eps

        # Nesterov is disabled and weight decay is forced to the decoupled form for the base class.
        super().__init__(
            params,
            lr,
            momentum,
            weight_decay,
            nesterov=False,
            weight_decay_method="decoupled",
            fp32_matmul_prec=fp32_matmul_prec,
            coefficient_type=coefficient_type,
            num_ns_steps=num_ns_steps,
            scale_mode=scale_mode,
            extra_scale_factor=extra_scale_factor,
            use_syrk=use_syrk,
        )

    @torch.no_grad()  # type: ignore[misc]
    @override
    def _init_group(self, group: dict, skip_non_grad_params: bool = True) -> None:
        """Lazily creates the per-parameter state for every parameter of ``group``.

        Args:
            group: Parameter group whose ``"params"`` are initialized.
            skip_non_grad_params: If ``True``, parameters whose ``grad`` is ``None`` are skipped.

        Raises:
            TypeError: If a (non-skipped) parameter is not 2D.
        """
        for p in group["params"]:
            if skip_non_grad_params and p.grad is None:
                continue
            if p.dim() != 2:
                raise TypeError("Muown is only supported for 2D parameters")
            state = self.state[p]
            if not state:
                # Seed g and v_norm from the current weight so Muown starts from the same point as
                # Muon. The norm is taken in fp32 (the precision the step works in) rather than in the
                # parameter dtype, and floored so an all-zero weight row does not give g=0, which
                # would make u = weight / g a 0/0 NaN on the first step.
                row_norm = p.to(torch.float32).norm(dim=1, keepdim=True).clamp_min(_ROW_NORM_FLOOR)
                state["step"] = 0
                state["g"] = row_norm.clone()
                state["v_norm"] = row_norm.clone()
                state["momentum_buffer"] = torch.zeros_like(p, dtype=torch.float32)
                state["m_g"] = torch.zeros_like(row_norm)
                state["v_g"] = torch.zeros_like(row_norm)

    if TYPE_CHECKING:

        @overload
        def step(self, closure: None = ...) -> None: ...

        @overload
        def step(self, closure: Callable[[], float]) -> float: ...

    @torch.no_grad()  # type: ignore[misc]
    @override
    def step(self, closure: Callable[[], float] | None = None) -> float | None:
        """Performs a single optimization step.

        For every parameter with a gradient: decompose the weight and gradient into magnitude and
        direction parts, take a Muon step on the direction and an Adam step (plus decoupled weight
        decay) on the magnitude, then write the recomposed weight back into the parameter.

        Args:
            closure: Unsupported; must be ``None``.

        Returns:
            Always ``None``.

        Raises:
            ValueError: If ``closure`` is not ``None``.
            TypeError: If a parameter with a gradient is not 2D (raised by :meth:`_init_group`).
        """
        if closure is not None:
            raise ValueError("closure is not supported")

        for group in self.param_groups:
            self._init_group(group)

            lr = group["lr"]
            momentum = group["momentum"]
            weight_decay = group["weight_decay"]
            for p in group["params"]:
                if p.grad is None:
                    continue  # pragma: no cover

                state = self.state[p]
                state["step"] += 1
                g = state["g"]
                v_norm = state["v_norm"]

                # All optimizer math runs in fp32 regardless of the parameter dtype.
                v, grad_g, grad_v = _weight_norm_decompose(p.to(torch.float32), p.grad.to(torch.float32), g, v_norm)

                # Direction: EMA momentum, then scaled orthogonalization (Muon update).
                state["momentum_buffer"].lerp_(grad_v, 1 - momentum)
                with utils.fp32_matmul_precision(self.fp32_matmul_prec):
                    direction_update = self.scaled_orthogonalize_fn(state["momentum_buffer"])
                v_new = v.add(direction_update, alpha=-lr)

                # Magnitude: bias-corrected Adam; m_g / v_g are the Adam moment buffers.
                magnitude_update = update_functions.calculate_adam_update(
                    grad_g,
                    state["m_g"],
                    state["v_g"],
                    betas=self.betas,
                    eps=self.adam_eps,
                    correct_bias=True,
                    nesterov=False,
                    step=state["step"],
                )
                g.add_(magnitude_update, alpha=-lr)

                # Decoupled weight decay on the magnitude (the spectral-norm driver).
                self._apply_weight_decay_inplace(g, grad_g, lr, weight_decay)

                # Recompose W = g * v_new / ||v_new||_row. The norm is floored exactly like the seeded
                # one: it divides here and again (as the cached v_norm) in the next step, so a zero
                # row in v_new would otherwise write NaN into the weight and inf into grad_v.
                v_norm_new = v_new.norm(dim=1, keepdim=True).clamp_min(_ROW_NORM_FLOOR)
                p.copy_(g * (v_new / v_norm_new))
                state["v_norm"] = v_norm_new

        return None


# ---------------------------------------------------------------------------
# tests/muown_reference.py (new file in this patch)
# ---------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Reference Muown implementation, adapted from the authors' code for use as a test oracle:
# https://github.com/kcc-lion/muown/blob/main/optim/muown.py
# Lion et al., "Muown: Row-Norm Control for Muon Optimization", arXiv:2605.10797 (paper: CC BY 4.0).
# The repository did not declare a code license at the time of copying.
"""Reference Muown implementation used as a test oracle."""

from typing import Callable

import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer


def _wn_pre_ns(W: Tensor, g: Tensor, v_norm: Tensor, grad_W: Tensor) -> tuple[Tensor, Tensor, Tensor]:
    """Reconstruct direction v from (W, g, v_norm) and split grad_W into (grad_g, grad_v).

    Args:
        W: Current weight (reductions run over ``dim=1``; presumably 2D -- not checked here).
        g: Per-row magnitude; used as a divisor, so it must be non-zero.
        v_norm: Cached per-row norm of the direction; used as a divisor, so it must be non-zero.
        grad_W: Gradient with respect to ``W``.

    Returns:
        ``(v, grad_g, grad_v)`` where ``u = W / g``, ``v = u * v_norm``,
        ``grad_g = sum_row(grad_W * u)`` and ``grad_v = (g / v_norm) * (grad_W - u * grad_g)``.
    """
    u = W / g
    v = u * v_norm
    grad_g = (grad_W * u).sum(dim=1, keepdim=True)
    grad_v = (g / v_norm) * (grad_W - u * grad_g)
    return v, grad_g, grad_v


def _wn_recompose(W: Tensor, g: Tensor, v_new: Tensor) -> Tensor:
    """Write W = g * v_new / ||v_new||_row in place and return the new row norms.

    No floor is applied to the row norms, so an all-zero row of ``v_new`` divides by zero.
    """
    v_norm_new = v_new.norm(dim=1, keepdim=True)
    W.copy_(g * (v_new / v_norm_new))
    return v_norm_new


class MuownReference(Optimizer):
    """Single-process reference Muown with an injected orthogonalization callable.

    Args:
        params: Parameters or parameter groups, forwarded to ``torch.optim.Optimizer``.
        orthogonalize_fn: Callable applied to the direction update before it is subtracted from ``v``.
        lr: Learning rate used for both the direction and the magnitude update.
        momentum: Heavy-ball momentum coefficient for the direction buffer ``m_v``.
        nesterov: If ``True``, use ``grad_v + momentum * m_v`` as the pre-orthogonalization update.
        betas: Adam ``(beta1, beta2)`` for the magnitude update.
        weight_decay: Weight decay coefficient; when non-zero, ``lr * weight_decay * W_old`` is
            subtracted from the recomposed weight and ``g`` is reset to the resulting row norms.
        adam_eps: Adam epsilon for the magnitude update.
    """

    def __init__(
        self,
        params,
        orthogonalize_fn: Callable[[Tensor], Tensor],
        lr: float = 3e-4,
        momentum: float = 0.95,
        nesterov: bool = False,
        betas: tuple[float, float] = (0.9, 0.95),
        weight_decay: float = 0.0,
        adam_eps: float = 1e-8,
    ) -> None:
        # The callable is kept on the instance (not in the param groups), so all groups share it.
        self._orthogonalize_fn = orthogonalize_fn
        defaults = dict(
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            betas=betas,
            weight_decay=weight_decay,
            adam_eps=adam_eps,
        )
        super().__init__(params, defaults)

    def _init_state_2d(self, p: Tensor, state: dict) -> None:
        """Seed the optimizer state of ``p``: ``g`` and ``v_norm`` start at the weight's row norms.

        Unlike the production optimizer, no dimensionality check and no floor on the row norms is
        applied here.
        """
        w_norm = p.data.norm(dim=1, keepdim=True)
        state["g"] = w_norm.clone()
        state["v_norm"] = w_norm.clone()
        state["m_v"] = torch.zeros_like(p.data)  # direction momentum buffer
        state["m_g"] = torch.zeros_like(w_norm)  # Adam first moment of the magnitude
        state["v_g"] = torch.zeros_like(w_norm)  # Adam second moment of the magnitude
        state["step"] = 0

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: Optional callable re-evaluating the model; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            momentum = group["momentum"]
            nesterov = group["nesterov"]
            betas = group["betas"]
            weight_decay = group["weight_decay"]
            adam_eps = group["adam_eps"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad
                state = self.state[p]

                # Lazy state creation on the first step that sees a gradient for this parameter.
                if len(state) == 0:
                    self._init_state_2d(p, state)

                state["step"] += 1
                step = state["step"]

                g = state["g"]
                v_norm = state["v_norm"]
                m_v = state["m_v"]
                m_g = state["m_g"]
                v_g = state["v_g"]
                # Snapshot of the pre-update weight; only needed (and only defined) with weight decay.
                if weight_decay != 0.0:
                    W_old = p.data.clone()

                # Fused: reconstruct v + compute weight norm gradients
                v, grad_g, grad_v = _wn_pre_ns(p.data, g, v_norm, grad)

                # Muon update on v: momentum + orthogonalization
                # Heavy-ball accumulation: the gradient is added without a (1 - momentum) factor.
                m_v.mul_(momentum).add_(grad_v)
                if nesterov:
                    update = grad_v.add(m_v, alpha=momentum)
                else:
                    update = m_v.clone()

                # Injected orthogonalization (folds in the 0.2 * sqrt(max(m, n)) scaling).
                update = self._orthogonalize_fn(update)
                v_new = v.add(update, alpha=-lr)

                # Adam update on g (small [out_features, 1] vectors)
                beta1, beta2 = betas
                m_g.mul_(beta1).add_(grad_g, alpha=1 - beta1)
                v_g.mul_(beta2).addcmul_(grad_g, grad_g, value=1 - beta2)
                # Bias-correction denominators for the first and second moment.
                bc1 = 1 - beta1**step
                bc2 = 1 - beta2**step
                g.addcdiv_(m_g / bc1, (v_g / bc2).sqrt().add_(adam_eps), value=-lr)

                # Fused: recompose W = g * v_new / ||v_new||, writes directly into p.data
                state["v_norm"] = _wn_recompose(p.data, g, v_new)
                if weight_decay != 0.0:
                    # Decay relative to the pre-update weight, then re-sync g with the decayed rows.
                    p.data.add_(W_old, alpha=-lr * weight_decay)
                    g.copy_(p.data.norm(dim=1, keepdim=True))

        return loss


# ---------------------------------------------------------------------------
# tests/test_muown.py (new file in this patch)
# ---------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from absl import flags, logging
from absl.testing import absltest, parameterized
from muown_reference import MuownReference

from emerging_optimizers.orthogonalized_optimizers.muown import Muown


flags.DEFINE_enum("device", "cpu", ["cpu", "cuda"], "Device to run tests on")
flags.DEFINE_integer("seed", None, "Random seed for reproducible tests")
FLAGS = flags.FLAGS


def setUpModule() -> None:
    """Seeds the torch RNGs (CPU and, when available, CUDA) if ``--seed`` was given."""
    if FLAGS.seed is not None:
        logging.info("Setting random seed to %d", FLAGS.seed)
        torch.manual_seed(FLAGS.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(FLAGS.seed)


class MuownTest(parameterized.TestCase):
    """Unit tests for :class:`Muown`, run on the device selected by ``--device``."""

    def setUp(self):
        super().setUp()
        self.device = FLAGS.device

    @parameterized.product(shape=[(8, 16), (16, 8), (33, 65)])
    def test_smoke(self, shape):
        """A few steps with weight decay keep the parameter finite."""
        p = torch.nn.Parameter(torch.randn(shape, device=self.device))
        opt = Muown([p], lr=1e-2, weight_decay=0.01)
        for _ in range(3):
            p.grad = torch.randn_like(p)
            opt.step()
        self.assertTrue(torch.isfinite(p).all())

    def test_zero_row_stays_finite(self):
        """An all-zero weight row must not produce NaN/inf in the weight or the magnitude state.

        Exercises the row-norm floor applied when the state is seeded: without it ``g`` would start
        at zero for that row and the first decomposition would divide 0 by 0.
        """
        weight = torch.randn((8, 16), device=self.device)
        weight[0].zero_()
        p = torch.nn.Parameter(weight)
        opt = Muown([p], lr=1e-2)
        for _ in range(3):
            p.grad = torch.randn_like(p)
            opt.step()
        self.assertTrue(torch.isfinite(p).all())
        self.assertTrue(torch.isfinite(opt.state[p]["g"]).all())

    def test_raises_on_non_2d(self):
        """1D and 3D parameters are rejected with ``TypeError`` on the first step."""
        for shape in [(8,), (2, 3, 4)]:
            # subTest reports which shape failed and still checks the remaining shapes.
            with self.subTest(shape=shape):
                p = torch.nn.Parameter(torch.randn(shape, device=self.device))
                p.grad = torch.randn_like(p)
                opt = Muown([p], lr=1e-2)
                with self.assertRaises(TypeError):
                    opt.step()

    def test_raises_on_closure(self):
        """Passing a closure to ``step`` raises ``ValueError``."""
        p = torch.nn.Parameter(torch.randn((8, 16), device=self.device))
        p.grad = torch.randn_like(p)
        opt = Muown([p], lr=1e-2)
        with self.assertRaises(ValueError):
            opt.step(lambda: 0.0)

    @parameterized.product(shape=[(8, 16), (16, 8)], weight_decay=[0.0, 0.1])
    def test_row_norm_equals_magnitude_state(self, shape, weight_decay):
        """After every step the row norms of the weight equal the stored magnitude ``g``."""
        p = torch.nn.Parameter(torch.randn(shape, device=self.device))
        opt = Muown([p], lr=1e-2, weight_decay=weight_decay)
        for _ in range(4):
            p.grad = torch.randn_like(p)
            opt.step()
            row_norm = p.detach().norm(dim=1, keepdim=True)
            torch.testing.assert_close(
                row_norm,
                opt.state[p]["g"],
                atol=0,
                rtol=1e-5,
            )

    @parameterized.product(shape=[(8, 16), (16, 8), (33, 65)], momentum=[0.0, 0.95])
    def test_close_reference(self, shape, momentum):
        """Muown matches the reference even though their momentum conventions differ.

        Muown uses EMA momentum (``buf = m*buf + (1-m)*grad_v``) while the reference uses heavy-ball
        (``buf = m*buf + grad_v``). The two buffers differ only by the constant factor ``(1 - m)``, which
        the scale-invariant Newton-Schulz orthogonalization removes, so the direction updates agree (to
        Newton-Schulz's eps-normalization tolerance). Both feed the same orthogonalization function.
        """
        p = torch.nn.Parameter(torch.randn(shape, device=self.device))
        p_ref = torch.nn.Parameter(p.detach().clone())

        opt = Muown(
            [p],
            lr=0.125,
            momentum=momentum,
            weight_decay=0.0,
            extra_scale_factor=0.25,
            fp32_matmul_prec="highest",
        )
        # Hold orthogonalization identical: feed the reference Muown's own scaled_orthogonalize_fn.
        opt_ref = MuownReference(
            [p_ref],
            orthogonalize_fn=opt.scaled_orthogonalize_fn,
            lr=0.125,
            momentum=momentum,
            nesterov=False,
            weight_decay=0.0,
        )

        for _ in range(5):
            grad = torch.randn_like(p)
            p.grad = grad.clone()
            p_ref.grad = grad.clone()
            opt.step()
            opt_ref.step()

            torch.testing.assert_close(
                p.detach(),
                p_ref.detach(),
                atol=1e-6,
                rtol=1e-5,
            )
            torch.testing.assert_close(
                opt.state[p]["g"],
                opt_ref.state[p_ref]["g"],
                atol=1e-7,
                rtol=1e-5,
            )


if __name__ == "__main__":
    absltest.main()