Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 27 additions & 28 deletions emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, override

import torch
Expand All @@ -35,8 +34,8 @@ class MuonHyperball(muon.Muon):

W_{t+1} = R \\cdot \\text{normalize}(W_t - \\text{lr} \\cdot R \\cdot \\text{normalize}(\\text{update}))

where :math:`R` is the Frobenius norm of :math:`W_t` (or a user-specified radius). This keeps
the weight matrix at constant scale while updating.
where :math:`R` is the user-specified Frobenius norm. This keeps the weight matrix at
constant scale while updating.

Warning:
This optimizer is experimental and may change in future versions.
Expand All @@ -47,56 +46,59 @@ class MuonHyperball(muon.Muon):

Args:
*args: Arguments passed to Muon.
hyperball_radius: Fixed radius for the hyperball. All parameters must
already have this Frobenius norm at construction time.
hyperball_eps: Epsilon for numerical stability in normalization.
Default: ``1e-8``.
hyperball_radius: Fixed radius for the hyperball. If ``None`` (default),
uses each parameter's initial Frobenius norm as its radius. If specified, all
parameters will be rescaled to have this radius at initialization.
**kwargs: Keyword arguments passed to Muon.

Raises:
ValueError: If any parameter has zero norm, or if a parameter's
Frobenius norm does not match ``hyperball_radius``.

"""

def __init__(
self,
*args: Any,
hyperball_eps: float = 1e-8,
hyperball_radius: float | None = None,
hyperball_radius: float,
hyperball_eps: float = 1e-15,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 hyperball_eps default weakens float32 numerical protection

The default dropped from 1e-8 to 1e-15. clamp_min(self.hyperball_eps) is used in both pre_weight_update_fn_inplace and post_weight_update_fn_inplace to guard against division by near-zero norms. Float32 machine epsilon is ~1.19e-7, so a gradient norm like 1e-10 is well above 1e-15 and would not be clamped, while the old 1e-8 would have clamped it. The result is that a near-zero gradient norm can now produce a scaling factor (hyperball_radius / update_norm) that is up to 10^7× larger than before, potentially creating very large intermediate values during the weight-update step before post_weight_update_fn_inplace projects back to the sphere. The PR description notes eps handling is being revisited across the repo — worth tracking this alongside that broader change.

**kwargs: Any,
) -> None:
self.hyperball_eps = hyperball_eps
self.hyperball_radius = hyperball_radius
super().__init__(*args, **kwargs)

# Validate and optionally rescale parameters based on hyperball_radius.
with torch.no_grad():
for group in self.param_groups:
for p in group["params"]:
p_norm = p.norm()
# Validate that parameter has non-zero norm.
if p_norm.item() == 0:
if p_norm <= hyperball_eps: # p_norm is non-negative, abs() is not needed
raise ValueError(
"MuonHyperball requires all parameters to have non-zero norm. "
"Found parameter with almost zero norm."
)
if not torch.isclose(
p_norm,
torch.tensor(self.hyperball_radius, dtype=p_norm.dtype, device=p_norm.device),
atol=0,
rtol=1e-5,
):
raise ValueError(
"MuonHyperball requires all parameters to have non-zero norm. Found parameter with zero norm."
f"hyperball_radius={self.hyperball_radius} was specified but a parameter "
f"has Frobenius norm {p_norm.item()}. Rescale your model parameters to the "
f"desired radius before constructing the optimizer."
)
# Rescale parameter to have the specified radius if provided.
if self.hyperball_radius is not None:
p.mul_(self.hyperball_radius / p_norm.clamp_min(self.hyperball_eps))

@override
def pre_weight_update_fn_inplace(self, p: torch.Tensor, update: torch.Tensor) -> None:
"""Store the original weight norm and normalize the update using Frobenius norm.
"""Normalize the update using Frobenius norm, scaled by R.

Args:
p: The parameter tensor.
update: The orthogonalized gradient tensor.
"""
# Use user-specified radius or compute R = ||W_t||_F (Frobenius norm)
R = self.hyperball_radius if self.hyperball_radius is not None else p.norm().item()
self.state[p]["hyperball_R"] = R

# Normalize the update in-place and scale by R
# This modifies update to be: R * normalize(update) using Frobenius norm.
update_norm = update.norm().clamp_min(self.hyperball_eps)
update.mul_(R / update_norm)
update.mul_(self.hyperball_radius / update_norm)

@override
def post_weight_update_fn_inplace(self, p: torch.Tensor) -> None:
Expand All @@ -105,9 +107,6 @@ def post_weight_update_fn_inplace(self, p: torch.Tensor) -> None:
Args:
p: The parameter tensor (already updated).
"""
# Retrieve R from per-parameter state
R = self.state[p]["hyperball_R"]

# Normalize the result and scale back by R: p = R * (p / ||p||_F) using Frobenius norm.
p_norm = p.norm().clamp_min(self.hyperball_eps)
p.mul_(R / p_norm)
p.mul_(self.hyperball_radius / p_norm)
49 changes: 11 additions & 38 deletions tests/test_orthogonalized_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
from absl import flags, logging
Expand Down Expand Up @@ -377,6 +376,7 @@ def test_norm_preservation(self, shape) -> None:
lr=0.01,
momentum=0.0,
weight_decay=0.0,
hyperball_radius=initial_norm,
)

# Run multiple steps with random gradients
Expand All @@ -392,47 +392,20 @@ def test_norm_preservation(self, shape) -> None:
rtol=1e-5,
)

@parameterized.product(
shape=[(5, 7), (33, 65), (127, 257)],
hyperball_radius=[0.5, 1.0, 2.0],
)
def test_hyperball_radius_rescales_params(self, shape, hyperball_radius) -> None:
"""Test that hyperball_radius kwarg rescales parameters to specified radius."""
test_param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=self.device))

opt = muon_hyperball.MuonHyperball(
[test_param],
lr=0.01,
hyperball_radius=hyperball_radius,
)

# After initialization, parameter should have the specified radius
torch.testing.assert_close(
test_param.norm(),
torch.tensor(hyperball_radius, device=self.device),
atol=1e-5,
rtol=1e-5,
)

# Run multiple steps with random gradients
for _ in range(5):
test_param.grad = torch.randn_like(test_param)
opt.step()
def test_zero_norm_raises_error(self) -> None:
test_param = nn.Parameter(torch.zeros((5, 7), device=self.device))
Comment on lines +395 to +396

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Missing docstring on test_zero_norm_raises_error — inconsistent with the surrounding tests which all carry docstrings describing their intent.

Suggested change
def test_zero_norm_raises_error(self) -> None:
test_param = nn.Parameter(torch.zeros((5, 7), device=self.device))
def test_zero_norm_raises_error(self) -> None:
"""Test that MuonHyperball raises ValueError for zero-norm parameters."""
test_param = nn.Parameter(torch.zeros((5, 7), device=self.device))

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@greptile, don't flag this if the test name is self-explanatory enough.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair point — test_zero_norm_raises_error is clear enough on its own. I'll remember not to flag missing docstrings on test methods when the test name already explains the intent.

Shall I create a rule: "Do not flag missing docstrings on test methods when the test name is sufficiently self-explanatory"?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

> Fair point — test_zero_norm_raises_error is clear enough on its own. I'll remember not to flag missing docstrings on test methods when the test name already explains the intent.
>
> Shall I create a rule: "Do not flag missing docstrings on test methods when the test name is sufficiently self-explanatory"?

yes


# Norm should remain at hyperball_radius after each step
torch.testing.assert_close(
test_param.norm(),
torch.tensor(hyperball_radius, device=self.device),
atol=1e-5,
rtol=1e-5,
)
with self.assertRaises(ValueError):
muon_hyperball.MuonHyperball([test_param], lr=0.01, hyperball_radius=1.0)

def test_zero_norm_raises_error(self) -> None:
"""Test that MuonHyperball raises ValueError for zero-norm parameters."""
test_param = nn.Parameter(torch.zeros((5, 7), dtype=torch.float32, device=self.device))
def test_radius_mismatch_raises_error(self) -> None:
"""Test that MuonHyperball raises ValueError when a parameter's norm does not match the radius."""
test_param = nn.Parameter(torch.randn((5, 7), dtype=torch.float32, device=self.device))
# Pick a radius that differs from the parameter's actual Frobenius norm.
mismatched_radius = test_param.norm().item() + 1.0

with self.assertRaises(ValueError):
muon_hyperball.MuonHyperball([test_param], lr=0.01)
muon_hyperball.MuonHyperball([test_param], lr=0.01, hyperball_radius=mismatched_radius)


class PolarGradTest(parameterized.TestCase):
Expand Down
Loading