Skip to content

EXPERIMENTAL tensor parallel REKLS#175

Merged
skyw merged 19 commits into
mainfrom
skyw/tp_rekls_exp
Jun 23, 2026
Merged

EXPERIMENTAL tensor parallel REKLS#175
skyw merged 19 commits into
mainfrom
skyw/tp_rekls_exp

Conversation

@skyw

@skyw skyw commented May 6, 2026

Copy link
Copy Markdown
Contributor

For experimental purpose only, DONOT merge yet.

UPDATE: The experiments turned out to be very successful. But double eigh is very slow for production use, defer merging to at least next release.

@copy-pr-bot

copy-pr-bot Bot commented May 6, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@greptile-apps

greptile-apps Bot commented May 6, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR adds an experimental tensor-parallel variant of REKLS (TpRekls) that all-gathers sharded gradients and Kronecker factors on every step, runs the full KL-Shampoo update on the gathered tensors (double eigh), then writes the local shard back to optimizer state. The PR explicitly marks itself as not ready to merge due to the per-step communication overhead.

  • tp_rekls.py: New TpRekls optimizer — all-gather of grad + L + R (3 collectives per step), divisibility check on both dimensions, correct in-place/copy-back split between replicated and TP paths, and bit-exact agreement with non-distributed REKLS verified by the new end-to-end test.
  • soap_utils.py: Adds get_eigenbasis_svd as a new standalone utility (exported in __all__) for completeness; tp_rekls.py itself uses get_eigenbasis_eigh, not svd.
  • CI / tests: CI loop is generalized to pick up all test_distributed_*_cpu.py files; a dedicated nproc_per_node=1 run validates the single-rank degenerate path.

Confidence Score: 5/5

The core optimizer logic is mathematically correct and matches the non-distributed REKLS update; the PR is explicitly flagged experimental and not intended for immediate merge.

The sharding scheme, gather/write-back round-trip, divisibility validation, and replicated fallback all check out. The only finding is a minor implicit-default in torch.linalg.svd inside get_eigenbasis_svd, which has no effect on current inputs (all Kronecker factors are square). No correctness defects were found in the changed code paths.

No files require special attention. soap_utils.py's new get_eigenbasis_svd is the only spot worth a second look before the function is used more broadly outside square matrices.

Important Files Changed

Filename Overview
emerging_optimizers/soap/tp_rekls.py New TpRekls optimizer: all-gather-based TP variant of REKLS with correct sharding, gather/write-back, and divisibility checks; 3 all-gathers per step (grad + L + R) are explicitly noted as the production bottleneck.
emerging_optimizers/soap/soap_utils.py Adds get_eigenbasis_svd; function is correct for square PSD Kronecker factors but calls torch.linalg.svd without explicit full_matrices=False, leaving an implicit default; not used by tp_rekls.py itself.
emerging_optimizers/soap/soap.py Docstring-only improvement to update_kronecker_factors_kl_shampoo: adds LaTeX math and renames eigenval_exp → eigval_exp in the Args block.
emerging_optimizers/utils/init.py Adds get_pg_size / get_pg_rank helpers with safe fallbacks (return 1 / 0 when distributed is not initialized or group is None); exported in all.
tests/ci/L0_Tests_CPU.sh CI now loops over all test_distributed_*_cpu.py files for n=8 and n=4, plus an explicit n=1 run for TpRekls single-process sanity check.
tests/test_distributed_rekls_cpu.py End-to-end test verifying TpRekls produces bit-identical updates to non-distributed REKLS across 5 steps for mixed partition_dim configurations; uses atol=0,rtol=0 with highest matmul precision.
tests/test_distributed_soap_utils_cpu.py Unit tests for all_gather_grad_and_kronecker_factors_tp validating exact round-trip for both partition_dim=0 and partition_dim=1 with integer-valued tensors for bit-exact comparison.

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant R0 as Rank 0 (TpRekls.step)
    participant Rn as Rank N (TpRekls.step)
    participant DG as dist (all_gather)

    Note over R0,Rn: partition_dim in {0, 1} path
    R0->>DG: all_gather(local_grad shards)
    Rn->>DG: all_gather(local_grad shards)
    DG-->>R0: full_grad (m x n)
    DG-->>Rn: full_grad (m x n)

    R0->>DG: "all_gather(L shards, dim=0)"
    Rn->>DG: "all_gather(L shards, dim=0)"
    DG-->>R0: full_L (m x m)
    DG-->>Rn: full_L (m x m)

    R0->>DG: "all_gather(R shards, dim=0)"
    Rn->>DG: "all_gather(R shards, dim=0)"
    DG-->>R0: full_R (n x n)
    DG-->>Rn: full_R (n x n)

    Note over R0,Rn: Each rank independently (identical computation)
    R0->>R0: eigh(full_L, full_R) → pre_eigenbasis
    R0->>R0: update_kronecker_factors_kl_shampoo(full_L, full_R, full_grad)
    R0->>R0: "state[L] = full_L.chunk(tp_size)[rank]"
    R0->>R0: "state[R] = full_R.chunk(tp_size)[rank]"
    R0->>R0: update_eigenbasis_and_exp_avgs → post_eigenbasis, exp_avg
    R0->>R0: precondition(full_grad) → Adam update → unprecondition
    R0->>R0: "p += local_shard(full_precond_update, dim=partition_dim)"
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
    participant R0 as Rank 0 (TpRekls.step)
    participant Rn as Rank N (TpRekls.step)
    participant DG as dist (all_gather)

    Note over R0,Rn: partition_dim in {0, 1} path
    R0->>DG: all_gather(local_grad shards)
    Rn->>DG: all_gather(local_grad shards)
    DG-->>R0: full_grad (m x n)
    DG-->>Rn: full_grad (m x n)

    R0->>DG: "all_gather(L shards, dim=0)"
    Rn->>DG: "all_gather(L shards, dim=0)"
    DG-->>R0: full_L (m x m)
    DG-->>Rn: full_L (m x m)

    R0->>DG: "all_gather(R shards, dim=0)"
    Rn->>DG: "all_gather(R shards, dim=0)"
    DG-->>R0: full_R (n x n)
    DG-->>Rn: full_R (n x n)

    Note over R0,Rn: Each rank independently (identical computation)
    R0->>R0: eigh(full_L, full_R) → pre_eigenbasis
    R0->>R0: update_kronecker_factors_kl_shampoo(full_L, full_R, full_grad)
    R0->>R0: "state[L] = full_L.chunk(tp_size)[rank]"
    R0->>R0: "state[R] = full_R.chunk(tp_size)[rank]"
    R0->>R0: update_eigenbasis_and_exp_avgs → post_eigenbasis, exp_avg
    R0->>R0: precondition(full_grad) → Adam update → unprecondition
    R0->>R0: "p += local_shard(full_precond_update, dim=partition_dim)"
Loading

Reviews (14): Last reviewed commit: "roll back laprop experiment" | Re-trigger Greptile

Comment thread emerging_optimizers/soap/rekls.py Outdated
Comment thread tests/test_distributed_rekls_cpu.py Outdated
Comment thread tests/test_distributed_soap_utils_cpu.py Outdated
@skyw skyw marked this pull request as draft May 6, 2026 20:32
@skyw skyw marked this pull request as ready for review May 6, 2026 20:40
@skyw skyw marked this pull request as draft May 6, 2026 20:40
@skyw

skyw commented May 6, 2026

Copy link
Copy Markdown
Contributor Author

@greptile check closure contract again

@skyw skyw marked this pull request as ready for review May 6, 2026 22:33
Comment thread emerging_optimizers/soap/rekls.py Outdated
@skyw

skyw commented May 6, 2026

Copy link
Copy Markdown
Contributor Author

@greptile overload can't be dropped because it is in upstream pytorch.

@skyw skyw force-pushed the skyw/tp_rekls_exp branch from e61f0f2 to 6eee42f Compare May 6, 2026 22:54
@skyw

skyw commented May 6, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 6eee42f

@greptile-apps

greptile-apps Bot commented May 6, 2026

Copy link
Copy Markdown
Contributor

That's a fair point. The @overload stubs are guarded under if TYPE_CHECKING: so they're only visible to type checkers (mypy/pyright), not at runtime — and since TpRekls inherits from optim.Optimizer, which declares those two overloads in its upstream type stubs, carrying them here is necessary to satisfy the @override contract and avoid type errors.

The actual runtime behavior is still made clear by the assert closure is None, "No support for closure" at the top of the implementation, so there's no ambiguity for callers at runtime. Acknowledged — no action needed here.

Tip: You can give me more context with [.greptile/rules.md](https://www.greptile.com/docs/code-review/custom-standards) so I get this right next time.

@skyw

skyw commented May 6, 2026

Copy link
Copy Markdown
Contributor Author

@mkhona-nvidia this actually turned out to be a very mergeable limited scope version, take a quick look.

@github-actions

github-actions Bot commented May 6, 2026

Copy link
Copy Markdown

Test Results

   77 files  + 25    147 suites  +25   1m 36s ⏱️ +5s
1 152 tests + 13  1 152 ✅ + 13  0 💤 ±0  0 ❌ ±0 
2 686 runs  +157  2 686 ✅ +157  0 💤 ±0  0 ❌ ±0 

Results for commit 761bf96. ± Comparison against base commit 84ddd02.

♻️ This comment has been updated with latest results.

@codecov

codecov Bot commented May 6, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 84.95575% with 17 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
emerging_optimizers/soap/tp_rekls.py 91.83% 4 Missing and 4 partials ⚠️
emerging_optimizers/soap/soap_utils.py 16.66% 5 Missing ⚠️
emerging_optimizers/utils/__init__.py 55.55% 2 Missing and 2 partials ⚠️

📢 Thoughts on this report? Let us know!

@skyw skyw marked this pull request as draft May 26, 2026 20:32
skyw added 11 commits June 23, 2026 13:34
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
skyw and others added 6 commits June 23, 2026 13:34
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Contain the experimental tensor-parallel REKLS code in its own module:
move the TpRekls class and the all_gather_grad_and_kronecker_factors_tp
helper out of rekls.py / tp_utils.py into tp_rekls.py, and update tests.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
@skyw skyw force-pushed the skyw/tp_rekls_exp branch from 208720c to d4d374a Compare June 23, 2026 21:06
@skyw skyw marked this pull request as ready for review June 23, 2026 21:28
@skyw

skyw commented Jun 23, 2026

Copy link
Copy Markdown
Contributor Author

@greptileai

Signed-off-by: Hao Wu <skyw@nvidia.com>
@skyw

skyw commented Jun 23, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 2b0d10a

Comment thread emerging_optimizers/soap/tp_rekls.py Outdated
Signed-off-by: Hao Wu <skyw@nvidia.com>
@skyw

skyw commented Jun 23, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 761bf96

@skyw skyw merged commit 93376d9 into main Jun 23, 2026
25 checks passed
@skyw skyw deleted the skyw/tp_rekls_exp branch June 23, 2026 23:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants