Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions emerging_optimizers/orthogonalized_optimizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from emerging_optimizers.orthogonalized_optimizers.mop import *
from emerging_optimizers.orthogonalized_optimizers.muon import *
from emerging_optimizers.orthogonalized_optimizers.muon_hyperball import *
from emerging_optimizers.orthogonalized_optimizers.muown import *
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import *
from emerging_optimizers.orthogonalized_optimizers.polargrad import *
from emerging_optimizers.orthogonalized_optimizers.scion import *
Expand Down
204 changes: 204 additions & 0 deletions emerging_optimizers/orthogonalized_optimizers/muown.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Callable, override


if TYPE_CHECKING:
from typing import overload

import torch
from torch.optim.optimizer import ParamsT

from emerging_optimizers import registry, utils
from emerging_optimizers.orthogonalized_optimizers.muon import Muon, MuonScaleT
from emerging_optimizers.orthogonalized_optimizers.muon_utils import NSCoeffT
from emerging_optimizers.scalar_optimizers import update_functions
from emerging_optimizers.utils import FP32MatmulPrecT


__all__ = ["Muown"]


@torch.compile
def _weight_norm_decompose(
weight: torch.Tensor,
grad: torch.Tensor,
g: torch.Tensor,
v_norm: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
r"""Reconstructs the direction and splits the gradient under the weight-norm reparameterization.

Args:
weight: The current 2D weight ``W``.
grad: The gradient ``grad_W`` with respect to ``W``.
g: Per-row magnitude, shape ``[rows, 1]``.
v_norm: Cached row norms ``||v||_row`` of the direction, shape ``[rows, 1]``.

Returns:
``(v, grad_g, grad_v)``: the reconstructed direction and the magnitude and direction gradients.
"""
u = weight / g
v = u * v_norm
grad_g = (grad * u).sum(dim=1, keepdim=True)
grad_v = (g / v_norm) * (grad - u * grad_g)
return v, grad_g, grad_v


@registry.register_optimizer("muown")
class Muown(Muon):
"""Muown: Muon with internal weight normalization (row-norm control).

Muown (Lion et al., *Muown: Row-Norm Control for Muon Optimization*, arXiv:2605.10797) is a drop-in
replacement for :class:`~emerging_optimizers.orthogonalized_optimizers.muon.Muon` that splits each 2D
weight into a per-row magnitude and a direction, then optimizes them under their natural geometries:


Args:
params: Iterable of parameters to optimize or dicts defining parameter groups.
lr: Learning rate shared by the direction and magnitude updates.
momentum: EMA momentum for the direction (Muon) update.
betas: Adam ``(beta1, beta2)`` for the magnitude update.
adam_eps: Adam epsilon for the magnitude update.
weight_decay: Decoupled weight decay coefficient, applied to the magnitude ``g``.
fp32_matmul_prec: Precision for the orthogonalization GEMM operations.
coefficient_type: Newton-Schulz coefficient set (see :class:`Muon`).
num_ns_steps: Number of Newton-Schulz iteration steps.
scale_mode: Update scale mode (see :func:`~emerging_optimizers.orthogonalized_optimizers.muon.get_muon_scale_factor`).
extra_scale_factor: Extra scale on the direction update; ``0.2`` matches Adam's update RMS norm.
use_syrk: Whether to use the Triton SYRK kernel for Newton-Schulz.
"""

def __init__(
self,
params: ParamsT,
lr: float = 3e-4,
momentum: float = 0.95,
weight_decay: float = 0.0,
*,
betas: tuple[float, float] = (0.9, 0.95),
adam_eps: float = 1e-8,
fp32_matmul_prec: FP32MatmulPrecT = "medium",
coefficient_type: NSCoeffT = "quintic",
num_ns_steps: int = 5,
scale_mode: MuonScaleT = "spectral",
extra_scale_factor: float = 1.0,
use_syrk: bool = False,
) -> None:
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta1: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta2: {betas[1]}")

self.betas = betas
self.adam_eps = adam_eps

super().__init__(
params,
lr,
momentum,
weight_decay,
nesterov=False,
weight_decay_method="decoupled",
fp32_matmul_prec=fp32_matmul_prec,
coefficient_type=coefficient_type,
num_ns_steps=num_ns_steps,
scale_mode=scale_mode,
extra_scale_factor=extra_scale_factor,
use_syrk=use_syrk,
)

@torch.no_grad() # type: ignore[misc]
@override
def _init_group(self, group: dict, skip_non_grad_params: bool = True) -> None:
for p in group["params"]:
if skip_non_grad_params and p.grad is None:
continue
if p.dim() != 2:
raise TypeError("Muown is only supported for 2D parameters")
state = self.state[p]
if len(state) == 0:
# Seed g, v from the current weight so Muown starts from the same point as Muon. Floor the
# row norm so an all-zero weight row does not give g=0, which would make u = weight / g a
# 0/0 NaN on the first step.
row_norm = p.norm(dim=1, keepdim=True).to(torch.float32).clamp_min(1e-12)
state["step"] = 0
state["g"] = row_norm.clone()
state["v_norm"] = row_norm.clone()
state["momentum_buffer"] = torch.zeros_like(p, dtype=torch.float32)
state["m_g"] = torch.zeros_like(row_norm)
state["v_g"] = torch.zeros_like(row_norm)

if TYPE_CHECKING:

@overload
def step(self, closure: None = ...) -> None: ...

@overload
def step(self, closure: Callable[[], float]) -> float: ...

@torch.no_grad() # type: ignore[misc]
@override
def step(self, closure: Callable[[], float] | None = None) -> float | None:
"""Performs a single optimization step.

Args:
closure: Unsupported; must be ``None``.
"""
if closure is not None:
raise ValueError("closure is not supported")

for group in self.param_groups:
self._init_group(group)

lr = group["lr"]
momentum = group["momentum"]
weight_decay = group["weight_decay"]
for p in group["params"]:
if p.grad is None:
continue # pragma: no cover

state = self.state[p]
Comment thread
skyw marked this conversation as resolved.
state["step"] += 1
g = state["g"]
v_norm = state["v_norm"]

v, grad_g, grad_v = _weight_norm_decompose(p.to(torch.float32), p.grad.to(torch.float32), g, v_norm)

state["momentum_buffer"].lerp_(grad_v, 1 - momentum)
with utils.fp32_matmul_precision(self.fp32_matmul_prec):
direction_update = self.scaled_orthogonalize_fn(state["momentum_buffer"])
v_new = v.add(direction_update, alpha=-lr)

magnitude_update = update_functions.calculate_adam_update(
grad_g,
state["m_g"],
state["v_g"],
betas=self.betas,
eps=self.adam_eps,
correct_bias=True,
nesterov=False,
step=state["step"],
)
g.add_(magnitude_update, alpha=-lr)

# Decoupled weight decay on the magnitude (the spectral-norm driver).
self._apply_weight_decay_inplace(g, grad_g, lr, weight_decay)

v_norm_new = v_new.norm(dim=1, keepdim=True)
p.copy_(g * (v_new / v_norm_new))
state["v_norm"] = v_norm_new

return None
143 changes: 143 additions & 0 deletions tests/muown_reference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Reference Muown implementation, adapted from the authors' code for use as a test oracle:
# https://github.com/kcc-lion/muown/blob/main/optim/muown.py
# Lion et al., "Muown: Row-Norm Control for Muon Optimization", arXiv:2605.10797 (paper: CC BY 4.0).
# The repository did not declare a code license at the time of copying.
"""Reference Muown implementation used as a test oracle."""

from typing import Callable

import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer


def _wn_pre_ns(W: Tensor, g: Tensor, v_norm: Tensor, grad_W: Tensor) -> tuple[Tensor, Tensor, Tensor]:
"""Reconstruct direction v from (W, g, v_norm) and split grad_W into (grad_g, grad_v)."""
u = W / g
v = u * v_norm
grad_g = (grad_W * u).sum(dim=1, keepdim=True)
grad_v = (g / v_norm) * (grad_W - u * grad_g)
return v, grad_g, grad_v


def _wn_recompose(W: Tensor, g: Tensor, v_new: Tensor) -> Tensor:
"""Write W = g * v_new / ||v_new||_row in place and return the new row norms."""
v_norm_new = v_new.norm(dim=1, keepdim=True)
W.copy_(g * (v_new / v_norm_new))
return v_norm_new


class MuownReference(Optimizer):
"""Single-process reference Muown with an injected orthogonalization callable."""

def __init__(
self,
params,
orthogonalize_fn: Callable[[Tensor], Tensor],
lr: float = 3e-4,
momentum: float = 0.95,
nesterov: bool = False,
betas: tuple[float, float] = (0.9, 0.95),
weight_decay: float = 0.0,
adam_eps: float = 1e-8,
) -> None:
self._orthogonalize_fn = orthogonalize_fn
defaults = dict(
lr=lr,
momentum=momentum,
nesterov=nesterov,
betas=betas,
weight_decay=weight_decay,
adam_eps=adam_eps,
)
super().__init__(params, defaults)

def _init_state_2d(self, p: Tensor, state: dict) -> None:
w_norm = p.data.norm(dim=1, keepdim=True)
state["g"] = w_norm.clone()
state["v_norm"] = w_norm.clone()
state["m_v"] = torch.zeros_like(p.data)
state["m_g"] = torch.zeros_like(w_norm)
state["v_g"] = torch.zeros_like(w_norm)
state["step"] = 0

@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
lr = group["lr"]
momentum = group["momentum"]
nesterov = group["nesterov"]
betas = group["betas"]
weight_decay = group["weight_decay"]
adam_eps = group["adam_eps"]

for p in group["params"]:
if p.grad is None:
continue

grad = p.grad
state = self.state[p]

if len(state) == 0:
self._init_state_2d(p, state)

state["step"] += 1
step = state["step"]

g = state["g"]
v_norm = state["v_norm"]
m_v = state["m_v"]
m_g = state["m_g"]
v_g = state["v_g"]
if weight_decay != 0.0:
W_old = p.data.clone()

# Fused: reconstruct v + compute weight norm gradients
v, grad_g, grad_v = _wn_pre_ns(p.data, g, v_norm, grad)

# Muon update on v: momentum + orthogonalization
m_v.mul_(momentum).add_(grad_v)
if nesterov:
update = grad_v.add(m_v, alpha=momentum)
else:
update = m_v.clone()

# Injected orthogonalization (folds in the 0.2 * sqrt(max(m, n)) scaling).
update = self._orthogonalize_fn(update)
v_new = v.add(update, alpha=-lr)

# Adam update on g (small [out_features, 1] vectors)
beta1, beta2 = betas
m_g.mul_(beta1).add_(grad_g, alpha=1 - beta1)
v_g.mul_(beta2).addcmul_(grad_g, grad_g, value=1 - beta2)
bc1 = 1 - beta1**step
Comment thread
skyw marked this conversation as resolved.
bc2 = 1 - beta2**step
g.addcdiv_(m_g / bc1, (v_g / bc2).sqrt().add_(adam_eps), value=-lr)

# Fused: recompose W = g * v_new / ||v_new||, writes directly into p.data
state["v_norm"] = _wn_recompose(p.data, g, v_new)
if weight_decay != 0.0:
p.data.add_(W_old, alpha=-lr * weight_decay)
g.copy_(p.data.norm(dim=1, keepdim=True))

return loss
Loading
Loading