diff --git a/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py b/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py index c9a8c09..84adac9 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import Any, override import torch @@ -35,8 +34,8 @@ class MuonHyperball(muon.Muon): W_{t+1} = R \\cdot \\text{normalize}(W_t - \\text{lr} \\cdot R \\cdot \\text{normalize}(\\text{update})) - where :math:`R` is the Frobenius norm of :math:`W_t` (or a user-specified radius). This keeps - the weight matrix at constant scale while updating. + where :math:`R` is the user-specified Frobenius norm. This keeps the weight matrix at + constant scale while updating. Warning: This optimizer is experimental and may change in future versions. @@ -47,56 +46,59 @@ class MuonHyperball(muon.Muon): Args: *args: Arguments passed to Muon. + hyperball_radius: Fixed radius for the hyperball. All parameters must + already have this Frobenius norm at construction time. hyperball_eps: Epsilon for numerical stability in normalization. - Default: ``1e-8``. - hyperball_radius: Fixed radius for the hyperball. If ``None`` (default), - uses each parameter's initial Frobenius norm as its radius. If specified, all - parameters will be rescaled to have this radius at initialization. **kwargs: Keyword arguments passed to Muon. + Raises: + ValueError: If any parameter has zero norm, or if a parameter's + Frobenius norm does not match ``hyperball_radius``. + """ def __init__( self, *args: Any, - hyperball_eps: float = 1e-8, - hyperball_radius: float | None = None, + hyperball_radius: float, + hyperball_eps: float = 1e-15, **kwargs: Any, ) -> None: self.hyperball_eps = hyperball_eps self.hyperball_radius = hyperball_radius super().__init__(*args, **kwargs) - # Validate and optionally rescale parameters based on hyperball_radius. with torch.no_grad(): for group in self.param_groups: for p in group["params"]: p_norm = p.norm() - # Validate that parameter has non-zero norm. - if p_norm.item() == 0: + if p_norm <= hyperball_eps: # p_norm is non-negative, abs() is not needed + raise ValueError( + "MuonHyperball requires all parameters to have non-zero norm. " + "Found parameter with almost zero norm." + ) + if not torch.isclose( + p_norm, + torch.tensor(self.hyperball_radius, dtype=p_norm.dtype, device=p_norm.device), + atol=0, + rtol=1e-5, + ): raise ValueError( - "MuonHyperball requires all parameters to have non-zero norm. Found parameter with zero norm." + f"hyperball_radius={self.hyperball_radius} was specified but a parameter " + f"has Frobenius norm {p_norm.item()}. Rescale your model parameters to the " + f"desired radius before constructing the optimizer." ) - # Rescale parameter to have the specified radius if provided. - if self.hyperball_radius is not None: - p.mul_(self.hyperball_radius / p_norm.clamp_min(self.hyperball_eps)) @override def pre_weight_update_fn_inplace(self, p: torch.Tensor, update: torch.Tensor) -> None: - """Store the original weight norm and normalize the update using Frobenius norm. + """Normalize the update using Frobenius norm, scaled by R. Args: p: The parameter tensor. update: The orthogonalized gradient tensor. """ - # Use user-specified radius or compute R = ||W_t||_F (Frobenius norm) - R = self.hyperball_radius if self.hyperball_radius is not None else p.norm().item() - self.state[p]["hyperball_R"] = R - - # Normalize the update in-place and scale by R - # This modifies update to be: R * normalize(update) using Frobenius norm. update_norm = update.norm().clamp_min(self.hyperball_eps) - update.mul_(R / update_norm) + update.mul_(self.hyperball_radius / update_norm) @override def post_weight_update_fn_inplace(self, p: torch.Tensor) -> None: @@ -105,9 +107,6 @@ def post_weight_update_fn_inplace(self, p: torch.Tensor) -> None: Args: p: The parameter tensor (already updated). """ - # Retrieve R from per-parameter state - R = self.state[p]["hyperball_R"] - # Normalize the result and scale back by R: p = R * (p / ||p||_F) using Frobenius norm. p_norm = p.norm().clamp_min(self.hyperball_eps) - p.mul_(R / p_norm) + p.mul_(self.hyperball_radius / p_norm) diff --git a/tests/test_orthogonalized_optimizer.py b/tests/test_orthogonalized_optimizer.py index 252252a..7ebb0f9 100644 --- a/tests/test_orthogonalized_optimizer.py +++ b/tests/test_orthogonalized_optimizer.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import torch import torch.nn as nn from absl import flags, logging @@ -377,6 +376,7 @@ def test_norm_preservation(self, shape) -> None: lr=0.01, momentum=0.0, weight_decay=0.0, + hyperball_radius=initial_norm, ) # Run multiple steps with random gradients @@ -392,47 +392,20 @@ def test_norm_preservation(self, shape) -> None: rtol=1e-5, ) - @parameterized.product( - shape=[(5, 7), (33, 65), (127, 257)], - hyperball_radius=[0.5, 1.0, 2.0], - ) - def test_hyperball_radius_rescales_params(self, shape, hyperball_radius) -> None: - """Test that hyperball_radius kwarg rescales parameters to specified radius.""" - test_param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=self.device)) - - opt = muon_hyperball.MuonHyperball( - [test_param], - lr=0.01, - hyperball_radius=hyperball_radius, - ) - - # After initialization, parameter should have the specified radius - torch.testing.assert_close( - test_param.norm(), - torch.tensor(hyperball_radius, device=self.device), - atol=1e-5, - rtol=1e-5, - ) - - # Run multiple steps with random gradients - for _ in range(5): - test_param.grad = torch.randn_like(test_param) - opt.step() + def test_zero_norm_raises_error(self) -> None: + test_param = nn.Parameter(torch.zeros((5, 7), device=self.device)) - # Norm should remain at hyperball_radius after each step - torch.testing.assert_close( - test_param.norm(), - torch.tensor(hyperball_radius, device=self.device), - atol=1e-5, - rtol=1e-5, - ) + with self.assertRaises(ValueError): + muon_hyperball.MuonHyperball([test_param], lr=0.01, hyperball_radius=1.0) - def test_zero_norm_raises_error(self) -> None: - """Test that MuonHyperball raises ValueError for zero-norm parameters.""" - test_param = nn.Parameter(torch.zeros((5, 7), dtype=torch.float32, device=self.device)) + def test_radius_mismatch_raises_error(self) -> None: + """Test that MuonHyperball raises ValueError when a parameter's norm does not match the radius.""" + test_param = nn.Parameter(torch.randn((5, 7), dtype=torch.float32, device=self.device)) + # Pick a radius that differs from the parameter's actual Frobenius norm. + mismatched_radius = test_param.norm().item() + 1.0 with self.assertRaises(ValueError): - muon_hyperball.MuonHyperball([test_param], lr=0.01) + muon_hyperball.MuonHyperball([test_param], lr=0.01, hyperball_radius=mismatched_radius) class PolarGradTest(parameterized.TestCase):