Add right polargrad function#234

Open
skyw wants to merge 3 commits into main from skyw/right_polargrad_fn

@skyw skyw commented Jun 24, 2026

Contributor

Add an orthogonalization function and an example of constructing a right PolarGrad optimizer.

skyw added 2 commits June 24, 2026 11:45
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>

copy-pr-bot Bot commented Jun 24, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

greptile-apps Bot commented Jun 24, 2026

Contributor

Greptile Summary

This PR introduces right_polargrad_orth_fn, a standalone function for one-sided (right) polar orthogonalization of tall matrices, intended for use with embedding and LM-head weights where the full two-sided polar factor is impractical. PolarGrad tests are reorganized from test_orthogonalized_optimizer.py into a dedicated test_polargrad.py.

  • Adds right_polargrad_orth_fn to polargrad.py: computes G (G^T G)^{-1/2} via eigh_with_fallback, scales by ||G||_*^alpha, and optionally zero-centers columns before and after the update.
  • Exposes the function in __all__ and adds an autofunction directive to the API docs.
  • Moves all PolarGrad-related tests to test_polargrad.py and adds a new RightPolarGradOrthFnTest suite covering right-orthogonal equivariance and optimizer integration.
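
For reference, the quantity in the first bullet can be related to the thin SVD. This is a short derivation under the assumption that G is tall with full column rank, ignoring the eps clamp; the symbols U, Σ, V, σ_i and λ_i are introduced here for illustration and do not appear in the PR.

```latex
\begin{aligned}
G &= U \Sigma V^\top, \qquad G^\top G = V \Sigma^2 V^\top = V \operatorname{diag}(\lambda_i)\, V^\top, \quad \lambda_i = \sigma_i^2, \\
(G^\top G)^{-1/2} &= V \operatorname{diag}\!\left(\lambda_i^{-1/2}\right) V^\top = V \Sigma^{-1} V^\top, \\
G\,(G^\top G)^{-1/2} &= U \Sigma V^\top V \Sigma^{-1} V^\top = U V^\top, \\
\|G\|_* &= \sum_i \sigma_i = \sum_i \sqrt{\lambda_i}.
\end{aligned}
```

So the eigendecomposition of the n × n Gram matrix is enough to obtain both the polar factor and the nuclear norm, without an SVD of the tall matrix itself.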

Confidence Score: 5/5

The change is safe to merge; it adds a self-contained new function and reorganizes existing tests without modifying any existing optimizer logic.

The core math in right_polargrad_orth_fn is correct — the Gram-matrix eigendecomposition, inverse square root construction, nuclear-norm scaling, and optional column centering all check out. The function is additive (no existing code paths are touched), and the previous docstring inaccuracies flagged in earlier review rounds have been fixed.

tests/test_polargrad.py — the alpha parameter of right_polargrad_orth_fn has no assertion covering its scaling effect.
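
One way the missing alpha coverage could be approached is sketched below. It is written against NumPy rather than the repository's PyTorch code, and `right_polar_update_svd` and `check_alpha_scaling` are hypothetical names; a real test would pass a thin wrapper around right_polargrad_orth_fn as `orth_fn`, whose exact signature is not shown in this thread.

```python
import numpy as np


def right_polar_update_svd(grad, alpha):
    """Reference update in SVD form: U V^T scaled by ||grad||_*^alpha."""
    u, s, vt = np.linalg.svd(grad, full_matrices=False)
    return (u @ vt) * s.sum() ** alpha


def check_alpha_scaling(orth_fn, grad, alpha, rtol=1e-4):
    """Assert orth_fn(grad, alpha) == orth_fn(grad, 0) * ||grad||_*^alpha."""
    nuclear_norm = np.linalg.norm(grad, ord="nuc")
    expected = orth_fn(grad, 0.0) * nuclear_norm ** alpha
    np.testing.assert_allclose(orth_fn(grad, alpha), expected, rtol=rtol, atol=1e-8)


# Exercise the check on the reference itself for a few exponents.
rng = np.random.default_rng(0)
g = rng.standard_normal((32, 4))
for alpha in (0.5, 1.0, 2.0):
    check_alpha_scaling(right_polar_update_svd, g, alpha)
```

The check only relies on the property that the alpha = 0 output is the unscaled polar factor, so it does not need to know how the function computes that factor internally.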

Important Files Changed

| Filename | Overview |
| --- | --- |
| emerging_optimizers/orthogonalized_optimizers/polargrad.py | Adds right_polargrad_orth_fn: mathematically correct right-polar orthogonalization via eigendecomposition of the Gram matrix, with optional column centering and nuclear-norm scaling. Previous docstring issues are addressed in this revision. |
| tests/test_polargrad.py | New standalone test file for PolarGrad and RightPolarGradOrthFn. Covers right-orthogonal equivariance and basic smoke/integration tests, but alpha != 1.0 scaling is not directly verified by any assertion. |
| tests/test_orthogonalized_optimizer.py | Removes PolarGrad import and PolarGradTest class, moved to the dedicated test_polargrad.py; no coverage is lost. |
| docs/apidocs/orthogonalized-optimizers.md | Adds autofunction directive for right_polargrad_orth_fn to the API docs. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["right_polargrad_orth_fn(grad)"] --> B{"center_rows?"}
    B -- Yes --> C["m = grad − mean(grad, dim=0)"]
    B -- No --> D["m = grad.to(float32)"]
    C --> E["eigh_with_fallback(mᵀm)\n→ eigvals, eigvecs (descending)"]
    D --> E
    E --> F["eigvals.clamp_min_(eps)"]
    F --> G["right_gram_inv_sqrt = V · diag(λ⁻½) · Vᵀ"]
    G --> H["u = m @ right_gram_inv_sqrt\n(right polar factor)"]
    F --> I["nuclear_norm = Σ √λᵢ"]
    H --> J["update = u · ‖G‖_*^α · extra_scale_factor"]
    I --> J
    J --> K{"center_rows?"}
    K -- Yes --> L["update = update − mean(update, dim=0)"]
    K -- No --> M["return update.to(grad.dtype)"]
    L --> M
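
The steps in the flowchart above can be sketched in a few lines. This is a NumPy approximation for illustration, not the PR's PyTorch implementation: `right_polargrad_orth_sketch`, the `center` flag name, and the `eps` default are assumptions, and `np.linalg.eigh` stands in for the `eigh_with_fallback` helper.

```python
import numpy as np


def right_polargrad_orth_sketch(grad, alpha=1.0, eps=1e-7, center=False,
                                extra_scale_factor=1.0):
    """NumPy sketch of the flowchart; names and defaults are assumptions."""
    # The sketch casts to float32 on both branches; the flowchart shows the
    # cast only on the non-centering branch.
    m = np.asarray(grad, dtype=np.float32)
    if center:
        # Subtract the mean over rows (axis 0), so each column sums to zero.
        m = m - m.mean(axis=0, keepdims=True)

    # Eigendecomposition of the small n x n Gram matrix m^T m.
    eigvals, eigvecs = np.linalg.eigh(m.T @ m)
    eigvals = np.clip(eigvals, eps, None)

    # (m^T m)^{-1/2} = V diag(lambda^{-1/2}) V^T
    gram_inv_sqrt = (eigvecs * eigvals ** -0.5) @ eigvecs.T
    u = m @ gram_inv_sqrt  # right polar factor

    # Nuclear norm of m from the Gram eigenvalues: sum of sqrt(lambda_i).
    nuclear_norm = np.sqrt(eigvals).sum()
    update = u * nuclear_norm ** alpha * extra_scale_factor
    if center:
        update = update - update.mean(axis=0, keepdims=True)
    return update.astype(np.asarray(grad).dtype)


# Right-orthogonal equivariance: f(G Q) should match f(G) Q for orthogonal Q.
rng = np.random.default_rng(0)
g = rng.standard_normal((64, 8))
q, _ = np.linalg.qr(rng.standard_normal((8, 8)))
lhs = right_polargrad_orth_sketch(g @ q)
rhs = right_polargrad_orth_sketch(g) @ q
print(np.allclose(lhs, rhs, rtol=1e-3, atol=1e-3))
```

Working with the n × n Gram matrix keeps the eigendecomposition cheap when the input is tall, which is presumably why this form suits embedding and LM-head weights.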

Reviews (2): Last reviewed commit: "minor improvement"

Comment thread tests/test_polargrad.py
Comment thread emerging_optimizers/orthogonalized_optimizers/polargrad.py Outdated
Comment thread emerging_optimizers/orthogonalized_optimizers/polargrad.py Outdated
Comment thread emerging_optimizers/orthogonalized_optimizers/polargrad.py Outdated
Signed-off-by: Hao Wu <skyw@nvidia.com>

skyw commented Jun 24, 2026

Contributor Author

/ok to test 33aef2a

@github-actions

Test Results

   79 files  + 2    149 suites  +2   1m 52s ⏱️ +9s
1 149 tests + 5  1 149 ✅ + 5  0 💤 ±0  0 ❌ ±0 
2 680 runs  +10  2 680 ✅ +10  0 💤 ±0  0 ❌ ±0 

Results for commit 33aef2a. ± Comparison against base commit 93376d9.

codecov Bot commented Jun 24, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
