Add muown #236

Draft
skyw wants to merge 4 commits into main from skyw/muown

skyw commented Jun 24, 2026
Contributor

No description provided.

skyw added 3 commits June 23, 2026 19:15
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>

copy-pr-bot Bot commented Jun 24, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

greptile-apps Bot commented Jun 24, 2026

Contributor

Greptile Summary

This PR adds Muown, a variant of Muon that applies per-row weight normalization: each 2D weight matrix is decomposed into a per-row magnitude g (updated by Adam) and a unit-direction v (updated by orthogonalized momentum), then recomposed. The implementation inherits from Muon and provides a clean override of _init_group and step.

  • New optimizer (muown.py): implements the weight-norm reparameterization with @torch.compile-decorated decomposition helper, fp32 state buffers, and decoupled weight-decay applied to g before recomposition. The initial v_norm and g are clamped at 1e-12 to avoid NaN on zero rows.
  • Test oracle (muown_reference.py): a self-contained reference copied from the upstream repo; used only as a comparison target in test_close_reference, which documents that the EMA vs. heavy-ball momentum difference cancels under scale-invariant Newton-Schulz.
  • Tests (test_muown.py): cover smoke runs, non-2D rejection, closure rejection, and the row-norm invariant ‖p‖_row == g; the reference-comparison test correctly justifies the momentum convention difference in its docstring.
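The decomposition described above can be sketched in plain Python. This is a minimal illustration using the standard weight-normalization formulas, not the PR's `_weight_norm_decompose`; the helper names `init_row`, `weight_norm_grads`, and `recompose_row` are hypothetical, and rows are plain lists rather than tensors.

```python
import math


def row_norm(row):
    """Euclidean norm of one row."""
    return math.sqrt(sum(x * x for x in row))


def init_row(w_row, eps=1e-12):
    """Seed magnitude g and direction v from a weight row (g equals the row norm)."""
    g = max(row_norm(w_row), eps)  # floor mirrors the 1e-12 clamp mentioned above
    return g, list(w_row)


def weight_norm_grads(g, v_row, dw_row, eps=1e-12):
    """Map dL/dW for one row onto (dL/dg, dL/dv), assuming W = g * v / ||v||."""
    norm = max(row_norm(v_row), eps)
    unit = [x / norm for x in v_row]
    # dL/dg is the projection of dL/dW onto the unit direction.
    grad_g = sum(d * u for d, u in zip(dw_row, unit))
    # dL/dv is the component of dL/dW orthogonal to v, scaled by g / ||v||.
    grad_v = [(g / norm) * (d - grad_g * u) for d, u in zip(dw_row, unit)]
    return grad_g, grad_v


def recompose_row(g, v_row, eps=1e-12):
    """Rebuild the weight row so that its norm equals g."""
    norm = max(row_norm(v_row), eps)
    return [g * x / norm for x in v_row]


g, v = init_row([3.0, 4.0])
grad_g, grad_v = weight_norm_grads(g, v, [1.0, 0.0])
p = recompose_row(2.0, v)
print(g, grad_g, grad_v, row_norm(p))
```

Two properties are worth checking in such a sketch: `grad_v` is orthogonal to `v`, and the recomposed row has norm `g`, which is the row-norm invariant the tests assert.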

Confidence Score: 5/5

Safe to merge; all tests cover the key invariants and the momentum-convention equivalence is correctly documented.

The core algorithm is mathematically sound and well-tested. There is one minor robustness gap: the per-step v_norm_new is stored without the clamp_min(1e-12) floor that guards the initial value, leaving a theoretical NaN path if a direction row collapses to zero. This is not reachable under normal training dynamics and does not affect any current test.
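The robustness gap can be illustrated with a small stand-in. The helper `normalize_row` below is hypothetical and uses plain lists; `max(norm, eps)` plays the role of `clamp_min(1e-12)`. In plain Python an unguarded division by a zero norm raises `ZeroDivisionError`, whereas with PyTorch float tensors the same 0/0 division yields NaN.

```python
import math


def normalize_row(v_row, eps=1e-12):
    """Return (unit direction, floored norm) for one row."""
    norm = math.sqrt(sum(x * x for x in v_row))
    safe_norm = max(norm, eps)  # plain-Python analogue of clamp_min(1e-12)
    return [x / safe_norm for x in v_row], safe_norm


# A collapsed (all-zero) row stays finite instead of failing on 0 / 0.
print(normalize_row([0.0, 0.0]))
# A regular row is unaffected by the floor.
print(normalize_row([3.0, 4.0]))
```

Applying the same floor to the per-step norm as to the initial one would keep the two code paths consistent.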

emerging_optimizers/orthogonalized_optimizers/muown.py — specifically the v_norm_new assignment at the end of the step loop.

Important Files Changed

| Filename | Overview |
| --- | --- |
| `emerging_optimizers/orthogonalized_optimizers/muown.py` | Introduces the Muown optimizer; the init correctly clamps row norms, but the per-step v_norm update is not clamped, creating a minor inconsistency. |
| `tests/muown_reference.py` | Reference oracle used only in tests; no issues beyond the pre-existing zero-norm gap (w_norm is not clamped at init), which won't surface with torch.randn inputs. |
| `tests/test_muown.py` | Good test coverage: smoke, non-2D rejection, closure rejection, row-norm invariant, and reference comparison (with a clear docstring explaining the EMA vs. heavy-ball equivalence under scale-invariant Newton-Schulz). |
| `emerging_optimizers/orthogonalized_optimizers/__init__.py` | Adds muown to the package exports, correctly placed alphabetically between muon_hyperball and orthogonalized_optimizer. |
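The EMA vs. heavy-ball equivalence mentioned for the reference comparison can be checked with a toy example. The sketch below assumes both buffers start at zero and no Nesterov correction is applied; in that case the EMA buffer is exactly `(1 - m)` times the heavy-ball buffer, so any scale-invariant post-processing gives the same result. The function `normalize` is only a stand-in for scale-invariant Newton-Schulz, and `run_momentum` is a hypothetical helper.

```python
import math


def normalize(vec):
    """Scale-invariant stand-in for orthogonalization: unit-normalize a vector."""
    n = math.sqrt(sum(x * x for x in vec))
    return [x / n for x in vec]


def run_momentum(grads, m, ema):
    """Run EMA (buf = m*buf + (1-m)*g) or heavy-ball (buf = m*buf + g) momentum."""
    buf = [0.0] * len(grads[0])
    scale = (1.0 - m) if ema else 1.0
    outputs = []
    for g in grads:
        buf = [m * b + scale * x for b, x in zip(buf, g)]
        outputs.append(normalize(buf))
    return outputs


grads = [[1.0, 2.0], [0.5, -1.0], [2.0, 0.25]]
ema_out = run_momentum(grads, 0.9, ema=True)
hb_out = run_momentum(grads, 0.9, ema=False)
for a, b in zip(ema_out, hb_out):
    print(a, b)  # each pair agrees up to floating-point rounding
```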

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant Caller
    participant Muown
    participant _weight_norm_decompose
    participant Adam as Adam update (calculate_adam_update)
    participant NS as scaled_orthogonalize_fn (Newton-Schulz)

    Caller->>Muown: step()
    Muown->>Muown: _init_group() — seed g, v_norm from row norms
    loop for each param p
        Muown->>_weight_norm_decompose: (W_fp32, grad_fp32, g, v_norm)
        _weight_norm_decompose-->>Muown: v, grad_g, grad_v
        Muown->>Muown: momentum_buffer.lerp_(grad_v, 1-m)  [EMA]
        Muown->>NS: momentum_buffer → direction_update
        NS-->>Muown: orthogonalized direction
        Muown->>Muown: "v_new = v - lr * direction_update"
        Muown->>Adam: (grad_g, m_g, v_g, betas, step) → magnitude_update
        Adam-->>Muown: magnitude_update (in-place updates m_g, v_g)
        Muown->>Muown: "g -= lr * magnitude_update"
        Muown->>Muown: "_apply_weight_decay_inplace(g) — g *= (1 - wd*lr)"
        Muown->>Muown: "v_norm_new = norm(v_new, row)"
        Muown->>Muown: "p <- g * (v_new / v_norm_new)"
        Muown->>Muown: "state[v_norm] = v_norm_new"
    end
    Muown-->>Caller: None
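
Read literally, the loop body in the diagram corresponds to the following per-parameter updates. The notation is an assumption introduced here: $\eta$ for `lr`, $\lambda$ for `wd`, $m$ for the momentum coefficient, $M$ for `momentum_buffer`, $\mathrm{NS}$ for `scaled_orthogonalize_fn`, and $\Delta_{\text{Adam}}$ for the value returned by `calculate_adam_update`.

```latex
\begin{aligned}
M   &\leftarrow m\,M + (1-m)\,\nabla_v L
      && \text{(EMA, i.e. } \texttt{lerp\_(grad\_v, 1-m)}\text{)} \\
D   &= \mathrm{NS}(M) \\
v'  &= v - \eta\,D \\
g   &\leftarrow g - \eta\,\Delta_{\text{Adam}}(\nabla_g L) \\
g   &\leftarrow (1-\lambda\eta)\,g \\
p_i &\leftarrow g_i\,\frac{v'_i}{\lVert v'_i\rVert_2}
      && \text{(row } i\text{; } \lVert v'_i\rVert_2 \text{ is stored as the new v\_norm)}
\end{aligned}
```

The last line is where an unclamped $\lVert v'_i\rVert_2 = 0$ would produce a division by zero.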

Reviews (2): Last reviewed commit: "fix some minor issues"

Comment thread emerging_optimizers/orthogonalized_optimizers/muown.py Outdated
Comment thread tests/muown_reference.py
Comment thread emerging_optimizers/orthogonalized_optimizers/muown.py

skyw commented Jun 24, 2026

Contributor Author

OK, the code is sub-optimal and not safe to merge; converting to draft.

@skyw skyw marked this pull request as draft June 24, 2026 23:02
Signed-off-by: Hao Wu <skyw@nvidia.com>

skyw commented Jun 24, 2026

Contributor Author

@greptileai
