diff --git a/emerging_optimizers/soap/__init__.py b/emerging_optimizers/soap/__init__.py index 928bdef..d8e333d 100644 --- a/emerging_optimizers/soap/__init__.py +++ b/emerging_optimizers/soap/__init__.py @@ -12,11 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from emerging_optimizers.soap.moso import MOSO from emerging_optimizers.soap.rekls import REKLS from emerging_optimizers.soap.soap import SOAP __all__ = [ + "MOSO", "REKLS", "SOAP", ] diff --git a/emerging_optimizers/soap/moso.py b/emerging_optimizers/soap/moso.py new file mode 100644 index 0000000..241b54e --- /dev/null +++ b/emerging_optimizers/soap/moso.py @@ -0,0 +1,297 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from typing import TYPE_CHECKING, Callable, override


if TYPE_CHECKING:
    from typing import overload

import torch
from torch import optim
from torch.optim.optimizer import ParamsT

from emerging_optimizers import mixin as opt_mixin
from emerging_optimizers import registry, utils
from emerging_optimizers.scalar_optimizers import update_functions
from emerging_optimizers.soap import soap_utils
from emerging_optimizers.soap.soap import _clip_update_rms_in_place


__all__ = ["MOSO"]


@registry.register_optimizer("moso")
class MOSO(opt_mixin.WeightDecayMixin, optim.Optimizer):
    r"""Momentum One-Sided SOAP.

    MOSO tracks EMA momentum like Muon, accumulates a SOAP/Shampoo-style covariance of that momentum on the
    smaller matrix side, and applies an Adam update in the covariance eigenbasis.
    Conceptually, this is one-sided SOAP where ``G_t G_t^T`` is replaced by ``M_t M_t^T`` (or ``M_t^T M_t`` for the
    right side), and the update is computed by projecting the momentum into the eigenbasis, applying Adam there, and
    projecting back:

    .. math::

        C_t = \beta_s C_{t-1} + (1 - \beta_s) M_t M_t^T,\quad C_t = Q_M \Lambda_M Q_M^T

        U_t = Q_M \operatorname{Adam}(Q_M^T M_t)

    for the left-preconditioned case where ``M_t.shape[0] <= M_t.shape[1]``; the right-preconditioned case
    accumulates ``M_t^T M_t`` with the same EMA and computes ``U_t = \operatorname{Adam}(M_t Q_M) Q_M^T``.

    The covariance EMA coefficient actually applied at step ``t`` (1-based) is the bias-corrected value
    ``1 - (1 - shampoo_beta) / (1 - shampoo_beta ** t)``, so the first step stores the momentum outer product
    itself rather than a down-scaled copy of it.

    Only 2D parameters are supported. All optimizer state is kept in ``float32``.

    Args:
        params: Iterable of parameters to optimize or dicts defining parameter groups.
        lr: Learning rate.
        momentum: EMA coefficient for the Muon-style momentum.
        betas: Inner Adam beta parameters ``(beta1, beta2)``.
        shampoo_beta: EMA coefficient for the one-sided momentum covariance.
        eps: Inner Adam epsilon for numerical stability.
        weight_decay: Weight decay coefficient.
        max_update_rms: Clip the update RMS to this value (0 means no clipping).

    Raises:
        ValueError: If ``lr``, ``eps`` or ``weight_decay`` is negative, or if ``momentum``, ``shampoo_beta`` or
            either entry of ``betas`` lies outside ``[0, 1)``.
    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 3e-4,
        momentum: float = 0.95,
        betas: tuple[float, float] = (0.9, 0.95),
        shampoo_beta: float = 0.95,
        eps: float = 1e-8,
        weight_decay: float = 0.01,
        *,
        max_update_rms: float = 0.0,
    ) -> None:
        # Fail fast on hyperparameters that would otherwise surface as an opaque error deep inside ``step``
        # (e.g. ``shampoo_beta == 1`` makes the bias-correction denominator ``1 - shampoo_beta ** t`` zero).
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= momentum < 1.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        for index, beta in enumerate(betas):
            if not 0.0 <= beta < 1.0:
                raise ValueError(f"Invalid beta parameter at index {index}: {beta}")
        if not 0.0 <= shampoo_beta < 1.0:
            raise ValueError(f"Invalid shampoo_beta value: {shampoo_beta}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        # Attributes set on the instance (not in ``defaults``), so they are shared by all parameter groups.
        self.weight_decay_method = "decoupled"
        self.max_update_rms = max_update_rms

        defaults = {
            "lr": lr,
            "momentum": momentum,
            "betas": betas,
            "shampoo_beta": shampoo_beta,
            "eps": eps,
            "weight_decay": weight_decay,
        }
        super().__init__(params, defaults)

    @torch.no_grad()  # type: ignore[misc]
    def _init_group(
        self,
        group: dict,
        skip_non_grad_params: bool = True,
    ) -> None:
        """Performs lazy state initialization for 2D parameters.

        Args:
            group: Parameter group whose ``"params"`` entries should be initialized.
            skip_non_grad_params: If True, parameters whose ``grad`` is None are left untouched.

        Raises:
            TypeError: If a (non-skipped) parameter is not 2D.
        """
        for p in group["params"]:
            if skip_non_grad_params and p.grad is None:
                continue

            if p.dim() != 2:
                raise TypeError("MOSO is only supported for 2D tensors")

            state = self.state[p]
            if len(state) == 0:
                rows, cols = p.shape
                # The covariance and its eigenbasis live on the smaller side of the matrix.
                preconditioner_size = min(rows, cols)
                state["step"] = 0
                state["momentum_buffer"] = torch.zeros_like(p.data, dtype=torch.float32)
                state["exp_avg"] = torch.zeros_like(p.data, dtype=torch.float32)
                state["exp_avg_sq"] = torch.zeros_like(p.data, dtype=torch.float32)
                state["M"] = torch.zeros(
                    preconditioner_size,
                    preconditioner_size,
                    device=p.device,
                    dtype=torch.float32,
                )
                state["Q_M"] = torch.eye(preconditioner_size, device=p.device, dtype=torch.float32)

    if TYPE_CHECKING:

        @overload
        def step(self, closure: None = ...) -> None: ...

        @overload
        def step(self, closure: Callable[[], float]) -> float: ...

    @torch.no_grad()  # type: ignore[misc]
    @override
    def step(self, closure: Callable[[], float] | None = None) -> float | None:
        """Performs a single optimization step.

        Args:
            closure: Must be None; closures are not supported by this optimizer.

        Returns:
            Always None.

        Raises:
            ValueError: If ``closure`` is not None.
            TypeError: If a parameter with a gradient is not 2D.
        """
        if closure is not None:
            raise ValueError("closure is not supported")

        for group in self.param_groups:
            self._init_group(group)

            for p in group["params"]:
                if p.grad is None:
                    continue  # pragma: no cover

                grad = p.grad.to(torch.float32)
                state = self.state[p]
                curr_iter_1_based = state["step"] + 1

                self._apply_weight_decay_inplace(
                    p,
                    grad,
                    group["lr"],
                    group["weight_decay"],
                )

                # Muon-style EMA of the gradient (no bias correction on the momentum itself).
                state["momentum_buffer"].lerp_(grad, 1 - group["momentum"])
                momentum = state["momentum_buffer"]

                # Bias-corrected covariance EMA coefficient: equals 0 on the first step, so the covariance is
                # initialized to the momentum outer product, and approaches ``shampoo_beta`` afterwards.
                shampoo_beta = 1 - (1 - group["shampoo_beta"]) / (1 - group["shampoo_beta"] ** curr_iter_1_based)

                left_preconditioned = momentum.shape[0] <= momentum.shape[1]

                # Covariance update, eigenbasis refresh and both projections all run at the same fp32 matmul
                # precision setting, so a single context covers them.
                with utils.fp32_matmul_precision("highest"):
                    _update_one_sided_momentum_factor(
                        momentum_factor=state["M"],
                        momentum=momentum,
                        shampoo_beta=shampoo_beta,
                    )

                    # A full eigendecomposition is used only on the very first step; later steps refine the
                    # previous basis with one QR power-iteration step.
                    state["Q_M"], state["exp_avg"], state["exp_avg_sq"] = _update_eigenbasis_and_adam_exp_avgs(
                        momentum_factor=state["M"],
                        eigenbasis=state["Q_M"],
                        exp_avg=state["exp_avg"],
                        exp_avg_sq=state["exp_avg_sq"],
                        left_preconditioned=left_preconditioned,
                        use_eigh=state["step"] == 0,
                        power_iter_steps=1,
                    )

                    momentum_projected = _project_to_one_sided_eigenbasis(
                        x=momentum,
                        eigenbasis=state["Q_M"],
                        left_preconditioned=left_preconditioned,
                    )
                    adam_update = update_functions.calculate_adam_update(
                        momentum_projected,
                        state["exp_avg"],
                        state["exp_avg_sq"],
                        betas=group["betas"],
                        eps=group["eps"],
                        correct_bias=True,
                        nesterov=False,
                        step=curr_iter_1_based,
                    )
                    update = _project_from_one_sided_eigenbasis(
                        x=adam_update,
                        eigenbasis=state["Q_M"],
                        left_preconditioned=left_preconditioned,
                    )

                _clip_update_rms_in_place(update, self.max_update_rms)
                p.add_(update, alpha=-group["lr"])

                state["step"] += 1

        return None


@torch.no_grad()  # type: ignore[misc]
def _update_one_sided_momentum_factor(
    momentum_factor: torch.Tensor,
    momentum: torch.Tensor,
    shampoo_beta: float,
) -> None:
    """Update the smaller-side covariance of the Muon momentum.

    Args:
        momentum_factor: Covariance accumulator, updated in place.
        momentum: 2D momentum matrix.
        shampoo_beta: EMA coefficient; the new outer product is blended in with weight ``1 - shampoo_beta``.
    """
    left_preconditioned = momentum.shape[0] <= momentum.shape[1]
    # Transpose when needed so that ``X @ X.T`` always has the size of the smaller dimension.
    maybe_transposed_momentum = momentum if left_preconditioned else momentum.T
    momentum_factor.lerp_(maybe_transposed_momentum @ maybe_transposed_momentum.T, 1 - shampoo_beta)


@torch.no_grad()  # type: ignore[misc]
def _update_eigenbasis_and_adam_exp_avgs(
    momentum_factor: torch.Tensor,
    eigenbasis: torch.Tensor,
    exp_avg: torch.Tensor,
    exp_avg_sq: torch.Tensor,
    left_preconditioned: bool,
    *,
    use_eigh: bool,
    power_iter_steps: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Update one eigenbasis and keep Adam state aligned with that basis.

    Args:
        momentum_factor: Current one-sided momentum covariance.
        eigenbasis: Eigenbasis from the previous step.
        exp_avg: Adam first moment, expressed in the previous eigenbasis.
        exp_avg_sq: Adam second moment, expressed in the previous eigenbasis.
        left_preconditioned: True if the eigenbasis acts on the row dimension, False for the column dimension.
        use_eigh: If True, recompute the basis via ``soap_utils.get_eigenbasis_eigh``; otherwise refine the
            (sorted) previous basis via ``soap_utils.get_eigenbasis_qr``.
        power_iter_steps: Forwarded to ``soap_utils.get_eigenbasis_qr``.

    Returns:
        Tuple ``(updated_eigenbasis, exp_avg, exp_avg_sq)`` where ``exp_avg`` is re-expressed in the updated
        basis and ``exp_avg_sq`` is permuted to follow the sorted basis slots.
    """
    # Bring the first moment back to the original coordinates using the OLD basis before that basis changes.
    exp_avg = _project_from_one_sided_eigenbasis(
        x=exp_avg,
        eigenbasis=eigenbasis,
        left_preconditioned=left_preconditioned,
    )

    eigenbasis, exp_avg_sq = _sort_one_sided_eigenbasis_and_exp_avg_sq(
        momentum_factor=momentum_factor,
        eigenbasis=eigenbasis,
        exp_avg_sq=exp_avg_sq,
        left_preconditioned=left_preconditioned,
    )

    if use_eigh:
        (updated_eigenbasis,) = soap_utils.get_eigenbasis_eigh([momentum_factor])
    else:
        (updated_eigenbasis,) = soap_utils.get_eigenbasis_qr(
            [momentum_factor],
            [eigenbasis],
            power_iter_steps=power_iter_steps,
        )

    # Re-express the first moment in the NEW basis.
    exp_avg = _project_to_one_sided_eigenbasis(
        x=exp_avg,
        eigenbasis=updated_eigenbasis,
        left_preconditioned=left_preconditioned,
    )
    return updated_eigenbasis, exp_avg, exp_avg_sq


@torch.no_grad()  # type: ignore[misc]
def _sort_one_sided_eigenbasis_and_exp_avg_sq(
    momentum_factor: torch.Tensor,
    eigenbasis: torch.Tensor,
    exp_avg_sq: torch.Tensor,
    left_preconditioned: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sort eigenbasis slots by approximate eigenvalue and permute Adam second moments.

    Args:
        momentum_factor: Current one-sided momentum covariance.
        eigenbasis: Eigenbasis whose columns are reordered.
        exp_avg_sq: Adam second moment expressed in ``eigenbasis``.
        left_preconditioned: Selects which dimension of ``exp_avg_sq`` is permuted (0 if True, 1 if False).

    Returns:
        Tuple ``(sorted_eigenbasis, permuted_exp_avg_sq)``, both new tensors ordered by descending
        approximate eigenvalue.
    """
    # NOTE(review): presumably the diagonal of Q^T C Q, i.e. per-column eigenvalue estimates; confirm
    # against ``utils.eig.conjugate``.
    approx_eigvals = utils.eig.conjugate(momentum_factor, eigenbasis, diag=True)
    sort_idx = torch.argsort(approx_eigvals, descending=True, stable=True)
    sorted_eigenbasis = eigenbasis[:, sort_idx]
    exp_avg_sq_dim = 0 if left_preconditioned else 1
    return sorted_eigenbasis, exp_avg_sq.index_select(exp_avg_sq_dim, sort_idx)


@torch.no_grad()  # type: ignore[misc]
def _project_to_one_sided_eigenbasis(
    x: torch.Tensor,
    eigenbasis: torch.Tensor,
    left_preconditioned: bool,
) -> torch.Tensor:
    """Project a matrix into the smaller-side covariance eigenbasis.

    Returns ``eigenbasis.T @ x`` when ``left_preconditioned`` is True, otherwise ``x @ eigenbasis``.
    """
    if left_preconditioned:
        return eigenbasis.T @ x
    return x @ eigenbasis


@torch.no_grad()  # type: ignore[misc]
def _project_from_one_sided_eigenbasis(
    x: torch.Tensor,
    eigenbasis: torch.Tensor,
    left_preconditioned: bool,
) -> torch.Tensor:
    """Project a matrix from the smaller-side covariance eigenbasis.

    Returns ``eigenbasis @ x`` when ``left_preconditioned`` is True, otherwise ``x @ eigenbasis.T``.
    """
    if left_preconditioned:
        return eigenbasis @ x
    return x @ eigenbasis.T


# Patch metadata and license header of the next file in this diff (tests/test_moso.py):
# diff --git a/tests/test_moso.py b/tests/test_moso.py
# new file mode 100644
# index 0000000..2a94649
# --- /dev/null
# +++ b/tests/test_moso.py
# @@ -0,0 +1,142 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from absl import flags, logging
from absl.testing import absltest, parameterized

from emerging_optimizers import registry
from emerging_optimizers.soap import MOSO


flags.DEFINE_enum("device", "cpu", ["cpu", "cuda"], "Device to run tests on")
flags.DEFINE_integer("seed", None, "Random seed for reproducible tests")
FLAGS = flags.FLAGS


def setUpModule() -> None:
    """Seed the torch RNGs once per module when ``--seed`` is given."""
    if FLAGS.seed is not None:
        logging.info("Setting random seed to %d", FLAGS.seed)
        torch.manual_seed(FLAGS.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(FLAGS.seed)


class MOSOTest(parameterized.TestCase):
    """Unit tests for the MOSO optimizer."""

    @parameterized.parameters(  # type: ignore[misc]
        {"shape": (5, 3)},
        {"shape": (3, 5)},
        {"shape": (4, 4)},
    )
    def test_3steps_smoke(self, shape: tuple[int, int]) -> None:
        """Three steps run without error, keep the parameter finite, and actually move it."""
        param = torch.randn(shape, requires_grad=True, device=FLAGS.device)
        initial_param = param.detach().clone()
        optimizer = MOSO(
            [param],
            lr=0.001,
            weight_decay=0.01,
            momentum=0.9,
            shampoo_beta=0.95,
        )

        for _ in range(3):
            param.grad = torch.randn_like(param)
            optimizer.step()
            param.grad = None

        # A smoke test that only checks "no exception" would also pass on NaN/Inf parameters or a no-op step.
        self.assertTrue(
            torch.isfinite(param).all().item(),
            msg=f"Parameter contains non-finite values after 3 steps for shape {shape}",
        )
        self.assertFalse(
            torch.equal(param.detach(), initial_param),
            msg=f"Parameter was not updated after 3 steps for shape {shape}",
        )

    def test_registry(self) -> None:
        """The optimizer is registered under the name ``moso``."""
        self.assertIs(registry.get_optimizer_cls("moso"), MOSO)

    @parameterized.parameters(
        {"shape": (3, 5)},
        {"shape": (5, 3)},
    )
    def test_accumulates_momentum_covariance_on_smaller_side(self, shape: tuple[int, int]) -> None:
        """With all EMAs disabled, one step stores the gradient outer product on the smaller side."""
        grad = torch.randn(shape, device=FLAGS.device)
        param = torch.zeros(shape, requires_grad=True, device=FLAGS.device)
        param.grad = grad.clone()
        optimizer = MOSO(
            [param],
            lr=0.0,
            momentum=0.0,
            shampoo_beta=0.0,
            weight_decay=0.0,
        )

        optimizer.step()

        state = optimizer.state[param]
        # momentum=0 makes the momentum equal to the gradient; shampoo_beta=0 makes the covariance equal to
        # the latest outer product.
        expected_momentum_factor = grad @ grad.T if shape[0] <= shape[1] else grad.T @ grad
        torch.testing.assert_close(
            state["M"],
            expected_momentum_factor,
            atol=1e-6,
            rtol=1e-6,
            msg=lambda msg: f"Momentum covariance mismatch for shape {shape}:\n{msg}",
        )
        self.assertEqual(state["M"].shape, (min(shape), min(shape)))
        self.assertEqual(state["Q_M"].shape, (min(shape), min(shape)))

    @parameterized.parameters(
        {"shape": (4, 8)},
        {"shape": (8, 4)},
    )
    def test_no_ema_is_close_to_one_sided_adam_in_eigenbasis(self, shape: tuple[int, int]) -> None:
        """With all EMAs disabled, the applied update equals Adam applied in the stored eigenbasis."""
        torch.manual_seed(7)
        grad = torch.randn(shape, device=FLAGS.device)
        param = torch.zeros(shape, requires_grad=True, device=FLAGS.device)
        param.grad = grad.clone()
        lr = 0.125
        optimizer = MOSO(
            [param],
            lr=lr,
            momentum=0.0,
            betas=(0.0, 0.0),
            shampoo_beta=0.0,
            eps=1e-12,
            weight_decay=0.0,
        )

        optimizer.step()

        state = optimizer.state[param]
        eigenbasis = state["Q_M"]
        # With betas=(0, 0) the Adam update reduces to x / (|x| + eps) element-wise in the eigenbasis.
        if shape[0] <= shape[1]:
            projected = eigenbasis.T @ grad
            adam_projected = projected / (projected.abs() + optimizer.param_groups[0]["eps"])
            expected_update = eigenbasis @ adam_projected
        else:
            projected = grad @ eigenbasis
            adam_projected = projected / (projected.abs() + optimizer.param_groups[0]["eps"])
            expected_update = adam_projected @ eigenbasis.T

        # The parameter started at zero, so the applied update is recovered as -param / lr.
        applied_update = -param.detach() / lr
        torch.testing.assert_close(
            applied_update,
            expected_update,
            atol=0.0,
            rtol=1e-4,
            msg=lambda msg: f"MOSO no-EMA update did not match projected Adam update for shape {shape}:\n{msg}",
        )

    def test_non_2d_param_raises_type_error(self) -> None:
        """Stepping with a 1D parameter raises ``TypeError``."""
        param = torch.randn(10, requires_grad=True, device=FLAGS.device)
        optimizer = MOSO([param], lr=0.001)
        param.grad = torch.randn_like(param)

        with self.assertRaisesRegex(TypeError, "only supported for 2D"):
            optimizer.step()


if __name__ == "__main__":
    absltest.main()