Improve hyperball implementation#233

Merged
skyw merged 6 commits into main from skyw/improve_hyperball
Jun 23, 2026

@skyw skyw commented Jun 23, 2026

Contributor

Updated naming and interface.

eps handling needs to be revisited for the entire repo.

mkhona-nvidia and others added 6 commits June 23, 2026 10:19
Signed-off-by: mikail <mkhona@nvidia.com>
Signed-off-by: mikail <mkhona@nvidia.com>
Signed-off-by: mikail <mkhona@nvidia.com>
Signed-off-by: mikail <mkhona@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
@skyw skyw requested a review from mkhona-nvidia June 23, 2026 17:57

copy-pr-bot Bot commented Jun 23, 2026


This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@skyw skyw mentioned this pull request Jun 23, 2026

greptile-apps Bot commented Jun 23, 2026

Contributor

Greptile Summary

This PR simplifies MuonHyperball by making hyperball_radius a required constructor argument and removing the previous auto-rescaling behaviour; the initializer now validates that every parameter's Frobenius norm already matches the specified radius (within rtol=1e-5), raising a ValueError if not.

  • hyperball_radius is now mandatory and all parameters must already lie on the target hyperball at construction time — no silent in-place rescaling occurs.
  • Per-parameter hyperball_R state is eliminated; pre_weight_update_fn_inplace and post_weight_update_fn_inplace read self.hyperball_radius directly.
  • Tests are updated to pass hyperball_radius explicitly; the old rescaling test is replaced with a new radius-mismatch test.
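The construction-time checks described above can be sketched in plain Python. This is a minimal illustration, not the code in muon_hyperball.py: `frobenius_norm`, `check_on_hyperball`, and `project_to_radius` are hypothetical helpers, and nested lists stand in for tensors. The tolerance test uses the same rule as `torch.isclose` with `atol=0`, i.e. `|norm - radius| <= rtol * |radius|`.

```python
import math


def frobenius_norm(matrix):
    """Frobenius norm of a matrix given as a list of rows."""
    return math.sqrt(sum(x * x for row in matrix for x in row))


def check_on_hyperball(matrix, hyperball_radius, hyperball_eps=1e-15, rtol=1e-5):
    """Raise ValueError unless the matrix norm matches hyperball_radius.

    Hypothetical helper that mirrors the two checks in the summary.
    """
    norm = frobenius_norm(matrix)
    if norm <= hyperball_eps:
        raise ValueError(f"near-zero norm: {norm}")
    # Relative tolerance only (atol=0): |norm - radius| <= rtol * |radius|
    if abs(norm - hyperball_radius) > rtol * abs(hyperball_radius):
        raise ValueError(f"norm {norm} does not match radius {hyperball_radius}")


def project_to_radius(matrix, hyperball_radius):
    """Caller-side rescaling, since no silent rescaling happens at init."""
    scale = hyperball_radius / frobenius_norm(matrix)
    return [[x * scale for x in row] for row in matrix]


weights = [[3.0, 0.0], [0.0, 4.0]]  # Frobenius norm is 5.0
on_sphere = project_to_radius(weights, 1.0)
check_on_hyperball(on_sphere, 1.0)  # passes: norm matches the radius
```

Under this sketch, passing `weights` directly with radius 1.0 raises the mismatch error, so any rescaling has to be done explicitly by the caller before the optimizer is constructed.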

Confidence Score: 4/5

Safe to merge; the algorithm correctness is unchanged and the new strict-validation constructor is a clear improvement over silent rescaling.

The core optimizer logic is correct — removing per-param state and using self.hyperball_radius directly simplifies the code without changing behaviour. The main open question is the hyperball_eps default drop from 1e-8 to 1e-15, which the PR description explicitly flags for follow-up; for float32, this weakens the clamp_min floor in the update routines. The test suite covers norm preservation, zero-norm rejection, and the new mismatch error, but the missing docstring in test_zero_norm_raises_error is a minor inconsistency.

emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py — specifically the hyperball_eps default and its use in clamp_min calls.

Important Files Changed

Filename Overview
emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py Makes hyperball_radius a required argument, removes auto-rescaling on init, adds strict norm-match validation via torch.isclose, and simplifies per-step logic by eliminating per-param state storage; the hyperball_eps default drop from 1e-8 to 1e-15 weakens float32 clamp protection.
tests/test_orthogonalized_optimizer.py Updates tests to pass hyperball_radius explicitly, removes the auto-rescale test, and adds a new test_radius_mismatch_raises_error test; test_zero_norm_raises_error is missing its docstring.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["__init__(hyperball_radius=R)"] --> B["Set self.hyperball_radius = R\nSet self.hyperball_eps = eps"]
    B --> C["super().__init__()"]
    C --> D["For each param p"]
    D --> E{"p.norm() ≤ hyperball_eps?"}
    E -- Yes --> F["raise ValueError\n(near-zero norm)"]
    E -- No --> G{"isclose(p.norm(), R,\natol=0, rtol=1e-5)?"}
    G -- No --> H["raise ValueError\n(norm mismatch)"]
    G -- Yes --> I["Initialization OK"]
    I --> J["optimizer.step()"]
    J --> K["pre_weight_update_fn_inplace(p, update)"]
    K --> L["update_norm = update.norm()\n.clamp_min(hyperball_eps)"]
    L --> M["update *= R / update_norm"]
    M --> N["Muon weight update:\np = p - lr × update"]
    N --> O["post_weight_update_fn_inplace(p)"]
    O --> P["p_norm = p.norm()\n.clamp_min(hyperball_eps)"]
    P --> Q["p *= R / p_norm\n(project back to sphere)"]
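The per-step part of the flowchart (nodes K through Q) can be sketched as follows. This is an illustrative sketch in plain Python under stated assumptions: `hyperball_step`, `frobenius_norm`, and `scale` are hypothetical names, nested lists stand in for tensors, `max(norm, eps)` stands in for `clamp_min`, and the `update` argument is a placeholder for whatever update Muon produces.

```python
import math


def frobenius_norm(matrix):
    """Frobenius norm of a matrix given as a list of rows."""
    return math.sqrt(sum(x * x for row in matrix for x in row))


def scale(matrix, factor):
    """Multiply every entry of the matrix by factor."""
    return [[x * factor for x in row] for row in matrix]


def hyperball_step(param, update, lr, radius, eps=1e-15):
    """One step following the flowchart: rescale update, apply it, re-project."""
    # Pre-update: rescale the update so its norm equals the radius.
    update_norm = max(frobenius_norm(update), eps)
    update = scale(update, radius / update_norm)
    # Weight update: p = p - lr * update
    param = [
        [p - lr * u for p, u in zip(p_row, u_row)]
        for p_row, u_row in zip(param, update)
    ]
    # Post-update: project the parameter back onto the sphere.
    param_norm = max(frobenius_norm(param), eps)
    return scale(param, radius / param_norm)


param = [[0.6, 0.0], [0.0, 0.8]]    # already on the radius-1 sphere
update = [[0.0, 2.0], [-1.0, 0.0]]  # placeholder for the Muon update
new_param = hyperball_step(param, update, lr=0.1, radius=1.0)
```

Because the last operation is the projection, the returned parameter has norm equal to `radius` (up to rounding) regardless of the learning rate.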

Reviews (1): Last reviewed commit: "add test for raises" | Re-trigger Greptile

- hyperball_eps: float = 1e-8,
- hyperball_radius: float | None = None,
+ hyperball_radius: float,
+ hyperball_eps: float = 1e-15,

Contributor

P2 hyperball_eps default weakens float32 numerical protection

The default dropped from 1e-8 to 1e-15. clamp_min(self.hyperball_eps) is used in both pre_weight_update_fn_inplace and post_weight_update_fn_inplace to guard against division by near-zero norms. Float32 machine epsilon is ~1.19e-7, so a gradient norm like 1e-10 is well above 1e-15 and would not be clamped, while the old 1e-8 would have clamped it. The result is that a near-zero gradient norm can now produce a scaling factor (hyperball_radius / update_norm) that is up to 10^7× larger than before, potentially creating very large intermediate values during the weight-update step before post_weight_update_fn_inplace projects back to the sphere. The PR description notes eps handling is being revisited across the repo — worth tracking this alongside that broader change.
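The arithmetic behind this comment can be checked with a small sketch. It assumes the scaling factor is `hyperball_radius / clamp_min(update_norm, eps)`, as the comment describes; `update_scale` is a hypothetical helper, and `max` stands in for `clamp_min`.

```python
def update_scale(update_norm, radius, eps):
    """Scale factor applied to the update: radius / max(norm, eps)."""
    return radius / max(update_norm, eps)


radius = 1.0
tiny_norm = 1e-10  # example value taken from the review comment

old_scale = update_scale(tiny_norm, radius, eps=1e-8)   # norm is clamped to 1e-8
new_scale = update_scale(tiny_norm, radius, eps=1e-15)  # norm is not clamped

print(f"eps=1e-8 : scale = {old_scale:.1e}")
print(f"eps=1e-15: scale = {new_scale:.1e}")

# As the norm approaches zero, the scale is capped at radius / eps,
# so the worst-case ratio between the two defaults is 1e-8 / 1e-15.
cap_ratio = (radius / 1e-15) / (radius / 1e-8)
print(f"cap ratio = {cap_ratio:.1e}")
```

For the specific norm of 1e-10 the new default yields a factor about 100 times larger than the old one; the 10^7 figure is the ratio of the two caps, reached only as the norm drops to 1e-15 or below.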

Comment on lines +395 to +396
def test_zero_norm_raises_error(self) -> None:
    test_param = nn.Parameter(torch.zeros((5, 7), device=self.device))

Contributor

P2 Missing docstring on test_zero_norm_raises_error — inconsistent with the surrounding tests which all carry docstrings describing their intent.

Suggested change
- def test_zero_norm_raises_error(self) -> None:
-     test_param = nn.Parameter(torch.zeros((5, 7), device=self.device))
+ def test_zero_norm_raises_error(self) -> None:
+     """Test that MuonHyperball raises ValueError for zero-norm parameters."""
+     test_param = nn.Parameter(torch.zeros((5, 7), device=self.device))

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Contributor Author

@greptile, don't flag this if the test name is self-explanatory enough.

Contributor

Fair point — test_zero_norm_raises_error is clear enough on its own. I'll remember not to flag missing docstrings on test methods when the test name already explains the intent.

Shall I create a rule: "Do not flag missing docstrings on test methods when the test name is sufficiently self-explanatory"?

Contributor Author

yes

skyw commented Jun 23, 2026

Contributor Author

/ok to test f32adcf

@skyw skyw enabled auto-merge (squash) June 23, 2026 19:56
@skyw skyw merged commit ffc2e04 into main Jun 23, 2026
24 checks passed
@skyw skyw deleted the skyw/improve_hyperball branch June 23, 2026 21:45