Add right polargrad function #234
Signed-off-by: Hao Wu <skyw@nvidia.com>
Greptile Summary
This PR introduces
Confidence Score: 5/5
The change is safe to merge; it adds a self-contained new function and reorganizes existing tests without modifying any existing optimizer logic. The core math in tests/test_polargrad.py, the ...
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["right_polargrad_orth_fn(grad)"] --> B{"center_rows?"}
B -- Yes --> C["m = grad − mean(grad, dim=0)"]
B -- No --> D["m = grad.to(float32)"]
C --> E["eigh_with_fallback(mᵀm)\n→ eigvals, eigvecs (descending)"]
D --> E
E --> F["eigvals.clamp_min_(eps)"]
F --> G["right_gram_inv_sqrt = V · diag(λ⁻½) · Vᵀ"]
G --> H["u = m @ right_gram_inv_sqrt\n(right polar factor)"]
F --> I["nuclear_norm = Σ √λᵢ"]
H --> J["update = u · ‖G‖_*^α · extra_scale_factor"]
I --> J
J --> K{"center_rows?"}
K -- Yes --> L["update = update − mean(update, dim=0)"]
K -- No --> M["return update.to(grad.dtype)"]
L --> M
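The steps in the flowchart can be sketched in a few lines. The following is a minimal NumPy stand-in, not the PR's implementation: the function name `right_polargrad_orth_sketch`, the default values of `alpha`, `extra_scale_factor`, and `eps`, and the use of `np.linalg.eigh` in place of `eigh_with_fallback` are all assumptions made for illustration.

```python
import numpy as np


def right_polargrad_orth_sketch(grad, center_rows=False, alpha=1.0,
                                extra_scale_factor=1.0, eps=1e-8):
    """Hypothetical NumPy sketch of the flowchart; defaults are assumed."""
    grad = np.asarray(grad)
    m = grad.astype(np.float32)
    if center_rows:
        # Subtract the mean taken over dim 0, as in the flowchart.
        m = m - m.mean(axis=0, keepdims=True)

    # Eigendecomposition of the right Gram matrix m^T m (symmetric PSD).
    # np.linalg.eigh returns ascending eigenvalues; the flowchart says
    # descending, but the formulas below do not depend on the order.
    eigvals, eigvecs = np.linalg.eigh(m.T @ m)
    eigvals = np.maximum(eigvals, eps)  # stands in for clamp_min_(eps)

    # (m^T m)^(-1/2) = V diag(lambda^(-1/2)) V^T
    right_gram_inv_sqrt = (eigvecs * eigvals ** -0.5) @ eigvecs.T
    u = m @ right_gram_inv_sqrt  # right polar factor

    # Nuclear norm as the sum of singular values, sqrt(lambda_i).
    nuclear_norm = np.sqrt(eigvals).sum()
    update = u * nuclear_norm ** alpha * extra_scale_factor

    if center_rows:
        update = update - update.mean(axis=0, keepdims=True)
    return update.astype(grad.dtype)
```

For a full-column-rank input with `alpha=0.0` and `extra_scale_factor=1.0`, the result should have approximately orthonormal columns, which is a convenient sanity check when experimenting with the sketch.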
Reviews (2): Last reviewed commit: "minor improvement"
/ok to test 33aef2a
Codecov Report
✅ All modified and coverable lines are covered by tests.
Add an orthogonalization function and an example of constructing a right polargrad optimizer.