From 40ad7e7b44782e397cd57c6120918762e9d1959b Mon Sep 17 00:00:00 2001 From: mikail Date: Tue, 7 Apr 2026 09:17:04 -0700 Subject: [PATCH 1/6] clean up hyperball Signed-off-by: mikail --- .../muon_hyperball.py | 48 +++++++++++-------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py b/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py index c9a8c09d..66c58151 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py @@ -35,8 +35,8 @@ class MuonHyperball(muon.Muon): W_{t+1} = R \\cdot \\text{normalize}(W_t - \\text{lr} \\cdot R \\cdot \\text{normalize}(\\text{update})) - where :math:`R` is the Frobenius norm of :math:`W_t` (or a user-specified radius). This keeps - the weight matrix at constant scale while updating. + where :math:`R` is the user-specified hyperball radius. This keeps the weight matrix at + constant scale while updating. Warning: This optimizer is experimental and may change in future versions. @@ -49,52 +49,62 @@ class MuonHyperball(muon.Muon): *args: Arguments passed to Muon. hyperball_eps: Epsilon for numerical stability in normalization. Default: ``1e-8``. - hyperball_radius: Fixed radius for the hyperball. If ``None`` (default), - uses each parameter's initial Frobenius norm as its radius. If specified, all - parameters will be rescaled to have this radius at initialization. + hyperball_radius: Fixed radius for the hyperball. All parameters must + already have this Frobenius norm at construction time. **kwargs: Keyword arguments passed to Muon. + Raises: + ValueError: If any parameter has zero norm, or if a parameter's + Frobenius norm does not match ``hyperball_radius``. + """ def __init__( self, *args: Any, hyperball_eps: float = 1e-8, - hyperball_radius: float | None = None, + hyperball_radius: float, **kwargs: Any, ) -> None: self.hyperball_eps = hyperball_eps self.hyperball_radius = hyperball_radius super().__init__(*args, **kwargs) - # Validate and optionally rescale parameters based on hyperball_radius. with torch.no_grad(): for group in self.param_groups: for p in group["params"]: p_norm = p.norm() - # Validate that parameter has non-zero norm. - if p_norm.item() == 0: + if p_norm == 0: + raise ValueError( + "MuonHyperball requires all parameters to have non-zero norm. " + "Found parameter with zero norm." + ) + if not torch.isclose( + p_norm, + torch.tensor(self.hyperball_radius, dtype=p.dtype, device=p.device), + rtol=1e-5, + atol=1e-8, + ): raise ValueError( - "MuonHyperball requires all parameters to have non-zero norm. Found parameter with zero norm." + f"hyperball_radius={self.hyperball_radius} was specified but a parameter " + f"has Frobenius norm {p_norm.item()}. Rescale your model parameters to the " + f"desired radius before constructing the optimizer." ) - # Rescale parameter to have the specified radius if provided. - if self.hyperball_radius is not None: - p.mul_(self.hyperball_radius / p_norm.clamp_min(self.hyperball_eps)) @override def pre_weight_update_fn_inplace(self, p: torch.Tensor, update: torch.Tensor) -> None: - """Store the original weight norm and normalize the update using Frobenius norm. + """Normalize the update using Frobenius norm, scaled by R. Args: p: The parameter tensor. update: The orthogonalized gradient tensor. """ - # Use user-specified radius or compute R = ||W_t||_F (Frobenius norm) - R = self.hyperball_radius if self.hyperball_radius is not None else p.norm().item() - self.state[p]["hyperball_R"] = R + if "hyperball_R" not in self.state[p]: + self.state[p]["hyperball_R"] = torch.tensor( + self.hyperball_radius, dtype=p.dtype, device=p.device + ) + R = self.state[p]["hyperball_R"] - # Normalize the update in-place and scale by R - # This modifies update to be: R * normalize(update) using Frobenius norm. update_norm = update.norm().clamp_min(self.hyperball_eps) update.mul_(R / update_norm) From 1a204b698ce718e8812798bda8116b648511ac21 Mon Sep 17 00:00:00 2001 From: mikail Date: Tue, 7 Apr 2026 13:29:35 -0700 Subject: [PATCH 2/6] clean up hyperball Signed-off-by: mikail --- .../muon_hyperball.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py b/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py index 66c58151..ef8a0651 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py @@ -121,3 +121,21 @@ def post_weight_update_fn_inplace(self, p: torch.Tensor) -> None: # Normalize the result and scale back by R: p = R * (p / ||p||_F) using Frobenius norm. p_norm = p.norm().clamp_min(self.hyperball_eps) p.mul_(R / p_norm) + + @staticmethod + def _compute_tangent_projection( + param: torch.Tensor, grad_like: torch.Tensor + ) -> torch.Tensor: + """Compute the Riemannian gradient via tangent-space projection. + Frobenius sphere (entire matrix on a single sphere). + + Args: + param: Parameter tensor (2D). + grad_like: Gradient-like tensor (momentum buffer or gradient). + + Returns: + The tangent-space projected gradient. + """ + + projection = (param * grad_like).sum() / param.pow(2).sum().clamp(min=1e-12) + return grad_like - projection * param From 7f3be7fe3f83fa84bf0c7473a523566db0ddb0e2 Mon Sep 17 00:00:00 2001 From: mikail Date: Tue, 7 Apr 2026 13:31:39 -0700 Subject: [PATCH 3/6] clean up hyperball, revert some comments Signed-off-by: mikail --- .../muon_hyperball.py | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py b/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py index ef8a0651..1bde8dc8 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py @@ -35,7 +35,7 @@ class MuonHyperball(muon.Muon): W_{t+1} = R \\cdot \\text{normalize}(W_t - \\text{lr} \\cdot R \\cdot \\text{normalize}(\\text{update})) - where :math:`R` is the user-specified hyperball radius. This keeps the weight matrix at + where :math:`R` is the user-specified Frobenius norm. This keeps the weight matrix at constant scale while updating. Warning: @@ -122,20 +122,3 @@ def post_weight_update_fn_inplace(self, p: torch.Tensor) -> None: p_norm = p.norm().clamp_min(self.hyperball_eps) p.mul_(R / p_norm) - @staticmethod - def _compute_tangent_projection( - param: torch.Tensor, grad_like: torch.Tensor - ) -> torch.Tensor: - """Compute the Riemannian gradient via tangent-space projection. - Frobenius sphere (entire matrix on a single sphere). - - Args: - param: Parameter tensor (2D). - grad_like: Gradient-like tensor (momentum buffer or gradient). - - Returns: - The tangent-space projected gradient. - """ - - projection = (param * grad_like).sum() / param.pow(2).sum().clamp(min=1e-12) - return grad_like - projection * param From 1cc83c03b5cddee9210fe42e7744ef7ffb9c412f Mon Sep 17 00:00:00 2001 From: mikail Date: Tue, 7 Apr 2026 13:33:21 -0700 Subject: [PATCH 4/6] linting Signed-off-by: mikail --- .../orthogonalized_optimizers/muon_hyperball.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py b/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py index 1bde8dc8..8d75b747 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py @@ -100,9 +100,7 @@ def pre_weight_update_fn_inplace(self, p: torch.Tensor, update: torch.Tensor) -> update: The orthogonalized gradient tensor. """ if "hyperball_R" not in self.state[p]: - self.state[p]["hyperball_R"] = torch.tensor( - self.hyperball_radius, dtype=p.dtype, device=p.device - ) + self.state[p]["hyperball_R"] = torch.tensor(self.hyperball_radius, dtype=p.dtype, device=p.device) R = self.state[p]["hyperball_R"] update_norm = update.norm().clamp_min(self.hyperball_eps) @@ -121,4 +119,3 @@ def post_weight_update_fn_inplace(self, p: torch.Tensor) -> None: # Normalize the result and scale back by R: p = R * (p / ||p||_F) using Frobenius norm. p_norm = p.norm().clamp_min(self.hyperball_eps) p.mul_(R / p_norm) - From 5c1db4e14a5a5b9c229f42e452ab70f9bb42136c Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Tue, 23 Jun 2026 10:48:04 -0700 Subject: [PATCH 5/6] Make naming and interface more consistent Signed-off-by: Hao Wu --- .../muon_hyperball.py | 19 +++++---- tests/test_orthogonalized_optimizer.py | 39 +------------------ 2 files changed, 11 insertions(+), 47 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py b/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py index 8d75b747..1e94a548 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py @@ -48,7 +48,6 @@ class MuonHyperball(muon.Muon): Args: *args: Arguments passed to Muon. hyperball_eps: Epsilon for numerical stability in normalization. - Default: ``1e-8``. hyperball_radius: Fixed radius for the hyperball. All parameters must already have this Frobenius norm at construction time. **kwargs: Keyword arguments passed to Muon. @@ -62,8 +61,8 @@ class MuonHyperball(muon.Muon): def __init__( self, *args: Any, - hyperball_eps: float = 1e-8, hyperball_radius: float, + hyperball_eps: float = 1e-15, **kwargs: Any, ) -> None: self.hyperball_eps = hyperball_eps @@ -74,16 +73,16 @@ def __init__( for group in self.param_groups: for p in group["params"]: p_norm = p.norm() - if p_norm == 0: + if p_norm.abs() <= hyperball_eps: raise ValueError( "MuonHyperball requires all parameters to have non-zero norm. " - "Found parameter with zero norm." + "Found parameter with almost zero norm." ) if not torch.isclose( p_norm, - torch.tensor(self.hyperball_radius, dtype=p.dtype, device=p.device), + torch.tensor(self.hyperball_radius, dtype=p_norm.dtype, device=p_norm.device), + atol=0, rtol=1e-5, - atol=1e-8, ): raise ValueError( f"hyperball_radius={self.hyperball_radius} was specified but a parameter " @@ -99,9 +98,9 @@ def pre_weight_update_fn_inplace(self, p: torch.Tensor, update: torch.Tensor) -> p: The parameter tensor. update: The orthogonalized gradient tensor. """ - if "hyperball_R" not in self.state[p]: - self.state[p]["hyperball_R"] = torch.tensor(self.hyperball_radius, dtype=p.dtype, device=p.device) - R = self.state[p]["hyperball_R"] + if "hyperball_radius" not in self.state[p]: + self.state[p]["hyperball_radius"] = torch.tensor(self.hyperball_radius, dtype=p.dtype, device=p.device) + R = self.state[p]["hyperball_radius"] update_norm = update.norm().clamp_min(self.hyperball_eps) update.mul_(R / update_norm) @@ -114,7 +113,7 @@ def post_weight_update_fn_inplace(self, p: torch.Tensor) -> None: p: The parameter tensor (already updated). """ # Retrieve R from per-parameter state - R = self.state[p]["hyperball_R"] + R = self.state[p]["hyperball_radius"] # Normalize the result and scale back by R: p = R * (p / ||p||_F) using Frobenius norm. p_norm = p.norm().clamp_min(self.hyperball_eps) diff --git a/tests/test_orthogonalized_optimizer.py b/tests/test_orthogonalized_optimizer.py index 252252a0..8cd132ba 100644 --- a/tests/test_orthogonalized_optimizer.py +++ b/tests/test_orthogonalized_optimizer.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import torch import torch.nn as nn from absl import flags, logging @@ -377,6 +376,7 @@ def test_norm_preservation(self, shape) -> None: lr=0.01, momentum=0.0, weight_decay=0.0, + hyperball_radius=initial_norm, ) # Run multiple steps with random gradients @@ -392,47 +392,12 @@ def test_norm_preservation(self, shape) -> None: rtol=1e-5, ) - @parameterized.product( - shape=[(5, 7), (33, 65), (127, 257)], - hyperball_radius=[0.5, 1.0, 2.0], - ) - def test_hyperball_radius_rescales_params(self, shape, hyperball_radius) -> None: - """Test that hyperball_radius kwarg rescales parameters to specified radius.""" - test_param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=self.device)) - - opt = muon_hyperball.MuonHyperball( - [test_param], - lr=0.01, - hyperball_radius=hyperball_radius, - ) - - # After initialization, parameter should have the specified radius - torch.testing.assert_close( - test_param.norm(), - torch.tensor(hyperball_radius, device=self.device), - atol=1e-5, - rtol=1e-5, - ) - - # Run multiple steps with random gradients - for _ in range(5): - test_param.grad = torch.randn_like(test_param) - opt.step() - - # Norm should remain at hyperball_radius after each step - torch.testing.assert_close( - test_param.norm(), - torch.tensor(hyperball_radius, device=self.device), - atol=1e-5, - rtol=1e-5, - ) - def test_zero_norm_raises_error(self) -> None: """Test that MuonHyperball raises ValueError for zero-norm parameters.""" test_param = nn.Parameter(torch.zeros((5, 7), dtype=torch.float32, device=self.device)) with self.assertRaises(ValueError): - muon_hyperball.MuonHyperball([test_param], lr=0.01) + muon_hyperball.MuonHyperball([test_param], lr=0.01, hyperball_radius=1.0) class PolarGradTest(parameterized.TestCase): From f32adcf32532bb8e3c3a801bf387c38f174ce34d Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Tue, 23 Jun 2026 10:56:39 -0700 Subject: [PATCH 6/6] add test for raises Signed-off-by: Hao Wu --- .../orthogonalized_optimizers/muon_hyperball.py | 16 ++++------------ tests/test_orthogonalized_optimizer.py | 12 ++++++++++-- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py b/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py index 1e94a548..84adac99 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import Any, override import torch @@ -47,9 +46,9 @@ class MuonHyperball(muon.Muon): Args: *args: Arguments passed to Muon. - hyperball_eps: Epsilon for numerical stability in normalization. hyperball_radius: Fixed radius for the hyperball. All parameters must already have this Frobenius norm at construction time. + hyperball_eps: Epsilon for numerical stability in normalization. **kwargs: Keyword arguments passed to Muon. Raises: @@ -73,7 +72,7 @@ def __init__( for group in self.param_groups: for p in group["params"]: p_norm = p.norm() - if p_norm.abs() <= hyperball_eps: + if p_norm <= hyperball_eps: # p_norm is non-negative, abs() is not needed raise ValueError( "MuonHyperball requires all parameters to have non-zero norm. " "Found parameter with almost zero norm." @@ -98,12 +97,8 @@ def pre_weight_update_fn_inplace(self, p: torch.Tensor, update: torch.Tensor) -> p: The parameter tensor. update: The orthogonalized gradient tensor. """ - if "hyperball_radius" not in self.state[p]: - self.state[p]["hyperball_radius"] = torch.tensor(self.hyperball_radius, dtype=p.dtype, device=p.device) - R = self.state[p]["hyperball_radius"] - update_norm = update.norm().clamp_min(self.hyperball_eps) - update.mul_(R / update_norm) + update.mul_(self.hyperball_radius / update_norm) @override def post_weight_update_fn_inplace(self, p: torch.Tensor) -> None: @@ -112,9 +107,6 @@ def post_weight_update_fn_inplace(self, p: torch.Tensor) -> None: Args: p: The parameter tensor (already updated). """ - # Retrieve R from per-parameter state - R = self.state[p]["hyperball_radius"] - # Normalize the result and scale back by R: p = R * (p / ||p||_F) using Frobenius norm. p_norm = p.norm().clamp_min(self.hyperball_eps) - p.mul_(R / p_norm) + p.mul_(self.hyperball_radius / p_norm) diff --git a/tests/test_orthogonalized_optimizer.py b/tests/test_orthogonalized_optimizer.py index 8cd132ba..7ebb0f9d 100644 --- a/tests/test_orthogonalized_optimizer.py +++ b/tests/test_orthogonalized_optimizer.py @@ -393,12 +393,20 @@ def test_norm_preservation(self, shape) -> None: ) def test_zero_norm_raises_error(self) -> None: - """Test that MuonHyperball raises ValueError for zero-norm parameters.""" - test_param = nn.Parameter(torch.zeros((5, 7), dtype=torch.float32, device=self.device)) + test_param = nn.Parameter(torch.zeros((5, 7), device=self.device)) with self.assertRaises(ValueError): muon_hyperball.MuonHyperball([test_param], lr=0.01, hyperball_radius=1.0) + def test_radius_mismatch_raises_error(self) -> None: + """Test that MuonHyperball raises ValueError when a parameter's norm does not match the radius.""" + test_param = nn.Parameter(torch.randn((5, 7), dtype=torch.float32, device=self.device)) + # Pick a radius that differs from the parameter's actual Frobenius norm. + mismatched_radius = test_param.norm().item() + 1.0 + + with self.assertRaises(ValueError): + muon_hyperball.MuonHyperball([test_param], lr=0.01, hyperball_radius=mismatched_radius) + class PolarGradTest(parameterized.TestCase): def setUp(self):