Skip to content

Add stacked soap#235

Open
skyw wants to merge 5 commits into
mainfrom
skyw/stacked-soap
Open

Add stacked soap#235
skyw wants to merge 5 commits into
mainfrom
skyw/stacked-soap

Conversation

@skyw

@skyw skyw commented Jun 24, 2026

Copy link
Copy Markdown
Contributor

Save memory on grouped linear layers for MoE. Accuracy impact is yet to be tested.

skyw added 2 commits June 24, 2026 14:45
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
@skyw skyw requested a review from mkhona-nvidia June 24, 2026 22:31
@copy-pr-bot

copy-pr-bot Bot commented Jun 24, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@greptile-apps

greptile-apps Bot commented Jun 24, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR adds StackedSoap, a memory-saving subclass of SOAP for 3-D (batched) parameters such as grouped linear layers in MoE models. It transiently reshapes each 3-D parameter and its gradient into 2-D before delegating to the parent SOAP step, then copies the update back into the original storage.

  • _stack_2d / _unstack implement the bijective reshape: the batch dimension is merged into the smaller matrix edge ((b,m,n)(m,b·n) when n≤m, otherwise (b·m,n)). The data_ptr() comparison in the restore path correctly distinguishes view branches (no copy needed) from permute branches (independent buffer, copy required).
  • The try/finally guarantee ensures parameters are always restored to their 3-D shape even if super().step() raises (OOM, NaN check, etc.), and the tests cover smoke, roundtrip correctness, 2-D equivalence with plain SOAP, and 3-D equivalence with manually-stacked SOAP.

Confidence Score: 5/5

Safe to merge; the stacking logic and restore path are correct for all contiguous-input cases, and the try/finally prevents parameter corruption on exceptions.

The core swap-step-restore mechanism is sound: the data_ptr() comparison correctly distinguishes the view branch (n > m, contiguous) from the permute branch (n ≤ m, independent buffer), and the finally block handles both. The only gap is that 1-D bias parameters produce a cryptic unpack error rather than an informative one, but that does not affect the correctness of 2-D/3-D parameter optimization.

emerging_optimizers/soap/soap.py — specifically the shape-validation gap at step entry for 1-D parameters.

Important Files Changed

Filename Overview
emerging_optimizers/soap/soap.py Adds _stack_2d, _unstack helpers and StackedSoap subclass; the try/finally restore is correct and the data_ptr comparison correctly detects view vs copy branches. One usability gap: 1-D parameters (biases) crash with a cryptic ValueError instead of an actionable error.
tests/test_soap.py Adds StackedSoapTest with smoke, roundtrip, 2D-vs-SOAP equivalence, and 3D-vs-stacked-SOAP equivalence tests; covers both stacking branches and uses exact tolerance where appropriate.

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant Caller
    participant StackedSoap
    participant _stack_2d
    participant SOAP_step as SOAP.step()
    participant _unstack

    Caller->>StackedSoap: step()
    StackedSoap->>StackedSoap: enter try block
    loop for each param p with grad
        StackedSoap->>_stack_2d: _stack_2d(p.data)
        _stack_2d-->>StackedSoap: stacked_data (2D view or new buffer)
        StackedSoap->>_stack_2d: _stack_2d(p.grad)
        _stack_2d-->>StackedSoap: stacked_grad (2D view or new buffer)
        StackedSoap->>StackedSoap: swap p.data / p.grad to stacked 2D
        StackedSoap->>StackedSoap: saved.append((p, orig_data, orig_grad))
    end
    StackedSoap->>SOAP_step: super().step()
    Note over SOAP_step: _init_group (lazy state init to 2D shape), Kronecker factor update, Eigenbasis update, Adam update, p.add_(precond_update) in-place
    SOAP_step-->>StackedSoap: done
    StackedSoap->>StackedSoap: enter finally block
    loop for each (p, orig_data, orig_grad) in saved
        StackedSoap->>StackedSoap: "stacked = p.data"
        StackedSoap->>StackedSoap: "restore p.data = orig_data, p.grad = orig_grad"
        alt "stacked.data_ptr() != orig_data.data_ptr() (permute branch)"
            StackedSoap->>_unstack: _unstack(stacked, orig_shape)
            _unstack-->>StackedSoap: restored 3D tensor
            StackedSoap->>StackedSoap: orig_data.copy_(restored)
        else view branch - in-place update already in orig_data
            StackedSoap->>StackedSoap: skip copy
        end
    end
    StackedSoap-->>Caller: None
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
    participant Caller
    participant StackedSoap
    participant _stack_2d
    participant SOAP_step as SOAP.step()
    participant _unstack

    Caller->>StackedSoap: step()
    StackedSoap->>StackedSoap: enter try block
    loop for each param p with grad
        StackedSoap->>_stack_2d: _stack_2d(p.data)
        _stack_2d-->>StackedSoap: stacked_data (2D view or new buffer)
        StackedSoap->>_stack_2d: _stack_2d(p.grad)
        _stack_2d-->>StackedSoap: stacked_grad (2D view or new buffer)
        StackedSoap->>StackedSoap: swap p.data / p.grad to stacked 2D
        StackedSoap->>StackedSoap: saved.append((p, orig_data, orig_grad))
    end
    StackedSoap->>SOAP_step: super().step()
    Note over SOAP_step: _init_group (lazy state init to 2D shape), Kronecker factor update, Eigenbasis update, Adam update, p.add_(precond_update) in-place
    SOAP_step-->>StackedSoap: done
    StackedSoap->>StackedSoap: enter finally block
    loop for each (p, orig_data, orig_grad) in saved
        StackedSoap->>StackedSoap: "stacked = p.data"
        StackedSoap->>StackedSoap: "restore p.data = orig_data, p.grad = orig_grad"
        alt "stacked.data_ptr() != orig_data.data_ptr() (permute branch)"
            StackedSoap->>_unstack: _unstack(stacked, orig_shape)
            _unstack-->>StackedSoap: restored 3D tensor
            StackedSoap->>StackedSoap: orig_data.copy_(restored)
        else view branch - in-place update already in orig_data
            StackedSoap->>StackedSoap: skip copy
        end
    end
    StackedSoap-->>Caller: None
Loading

Reviews (4): Last reviewed commit: "qol improvement" | Re-trigger Greptile

Comment thread emerging_optimizers/soap/soap.py Outdated
Comment thread emerging_optimizers/soap/soap.py Outdated
Signed-off-by: Hao Wu <skyw@nvidia.com>
@skyw

skyw commented Jun 24, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 283d7ed

@github-actions

github-actions Bot commented Jun 24, 2026

Copy link
Copy Markdown

Test Results

   77 files  ± 0    149 suites  +2   1m 42s ⏱️ +3s
1 156 tests +12  1 156 ✅ +12  0 💤 ±0  0 ❌ ±0 
2 694 runs  +24  2 694 ✅ +24  0 💤 ±0  0 ❌ ±0 

Results for commit fdb8c27. ± Comparison against base commit 93376d9.

♻️ This comment has been updated with latest results.

@codecov

codecov Bot commented Jun 24, 2026

Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.

📢 Thoughts on this report? Let us know!

skyw added 2 commits June 24, 2026 20:58
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
@skyw

skyw commented Jun 25, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test fdb8c27

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant