EXPERIMENTAL tensor parallel REKLS #175
Greptile Summary
This PR adds an experimental tensor-parallel variant of REKLS (
Confidence Score: 5/5
The core optimizer logic is mathematically correct and matches the non-distributed REKLS update; the PR is explicitly flagged experimental and not intended for immediate merge. The sharding scheme, gather/write-back round-trip, divisibility validation, and replicated fallback all check out. The only finding is a minor implicit-default in
No files require special attention.
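The divisibility validation and replicated fallback mentioned in the summary can be illustrated with a minimal sketch. The helper names `shard_bounds` and `local_shape` are hypothetical and are not taken from the PR; the sketch also assumes that any `partition_dim` outside {0, 1} means the parameter is replicated on every rank.

```python
from typing import Optional, Tuple


def shard_bounds(dim_size: int, tp_size: int, rank: int) -> Tuple[int, int]:
    """Return the [start, stop) slice owned by `rank` along a sharded dimension.

    Hypothetical helper: raises if the dimension cannot be split evenly,
    since an even split is what lets every rank gather equally sized shards.
    """
    if tp_size < 1 or not 0 <= rank < tp_size:
        raise ValueError(f"invalid rank {rank} for tp_size {tp_size}")
    if dim_size % tp_size != 0:
        raise ValueError(
            f"dimension of size {dim_size} is not divisible by tp_size {tp_size}"
        )
    width = dim_size // tp_size
    return rank * width, (rank + 1) * width


def local_shape(
    shape: Tuple[int, int],
    partition_dim: Optional[int],
    tp_size: int,
    rank: int,
) -> Tuple[int, int]:
    """Shape of the local parameter shard for a 2-D parameter.

    Assumption: partition_dim outside {0, 1} selects the replicated fallback,
    in which case every rank holds the full parameter.
    """
    if partition_dim not in (0, 1):
        return shape  # replicated: nothing is sharded
    start, stop = shard_bounds(shape[partition_dim], tp_size, rank)
    local = list(shape)
    local[partition_dim] = stop - start
    return tuple(local)
```

For example, `local_shape((8, 6), 1, 2, 0)` gives the shape of rank 0's column shard, while passing `partition_dim=None` returns the full shape unchanged.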
Sequence Diagram
%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
participant R0 as Rank 0 (TpRekls.step)
participant Rn as Rank N (TpRekls.step)
participant DG as dist (all_gather)
Note over R0,Rn: partition_dim in {0, 1} path
R0->>DG: all_gather(local_grad shards)
Rn->>DG: all_gather(local_grad shards)
DG-->>R0: full_grad (m x n)
DG-->>Rn: full_grad (m x n)
R0->>DG: "all_gather(L shards, dim=0)"
Rn->>DG: "all_gather(L shards, dim=0)"
DG-->>R0: full_L (m x m)
DG-->>Rn: full_L (m x m)
R0->>DG: "all_gather(R shards, dim=0)"
Rn->>DG: "all_gather(R shards, dim=0)"
DG-->>R0: full_R (n x n)
DG-->>Rn: full_R (n x n)
Note over R0,Rn: Each rank independently (identical computation)
R0->>R0: eigh(full_L, full_R) → pre_eigenbasis
R0->>R0: update_kronecker_factors_kl_shampoo(full_L, full_R, full_grad)
R0->>R0: "state[L] = full_L.chunk(tp_size)[rank]"
R0->>R0: "state[R] = full_R.chunk(tp_size)[rank]"
R0->>R0: update_eigenbasis_and_exp_avgs → post_eigenbasis, exp_avg
R0->>R0: precondition(full_grad) → Adam update → unprecondition
R0->>R0: "p += local_shard(full_precond_update, dim=partition_dim)"
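The gather, compute, and write-back flow in the diagram can be approximated in a single process. The sketch below is not the PR's implementation: NumPy concatenation stands in for `dist.all_gather`, a plain exponential moving average of `G Gᵀ` and `Gᵀ G` stands in for `update_kronecker_factors_kl_shampoo`, the Adam step in the eigenbasis is left out, and the eigenbasis is computed once after the factor update rather than before and after as in the diagram. The function names and the `beta` parameter are hypothetical.

```python
import numpy as np


def all_gather_sim(shards, dim):
    """Stand-in for dist.all_gather followed by concatenation along `dim`."""
    return np.concatenate(shards, axis=dim)


def local_shard(full, tp_size, rank, dim):
    """Slice out the shard owned by `rank`; np.split raises on uneven splits."""
    return np.split(full, tp_size, axis=dim)[rank]


def simulated_tp_step(grad_shards, L_shards, R_shards, partition_dim, beta=0.9):
    """One simplified step, computed identically for every simulated rank."""
    tp_size = len(grad_shards)

    # 1. Gather the full gradient (m x n) and both Kronecker factors.
    full_grad = all_gather_sim(grad_shards, dim=partition_dim)
    full_L = all_gather_sim(L_shards, dim=0)  # m x m
    full_R = all_gather_sim(R_shards, dim=0)  # n x n

    # 2. Placeholder factor update (plain EMA, NOT the KL-Shampoo rule).
    full_L = beta * full_L + (1.0 - beta) * full_grad @ full_grad.T
    full_R = beta * full_R + (1.0 - beta) * full_grad.T @ full_grad

    # 3. Eigenbases of the symmetric factors.
    _, Q_L = np.linalg.eigh(full_L)
    _, Q_R = np.linalg.eigh(full_R)

    # 4. Rotate into the eigenbasis and back. A real optimizer would apply
    #    an Adam-style update to `rotated` between these two lines.
    rotated = Q_L.T @ full_grad @ Q_R
    full_update = Q_L @ rotated @ Q_R.T

    # 5. Write back: each rank keeps only its own slice of every tensor.
    ranks = range(tp_size)
    update_shards = [local_shard(full_update, tp_size, r, partition_dim) for r in ranks]
    new_L = [local_shard(full_L, tp_size, r, 0) for r in ranks]
    new_R = [local_shard(full_R, tp_size, r, 0) for r in ranks]
    return update_shards, new_L, new_R
```

Because the Adam step is omitted, the rotation round-trip returns the gathered gradient itself (up to floating-point error), which makes the sketch easy to sanity-check. It also makes the trade-off visible: every rank repeats the full-size factor update and `eigh`, and only the storage of `L`, `R`, and the parameter is sharded.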
Reviews (14): Last reviewed commit: "roll back laprop experiment"

@greptile check closure contract again

@greptile overload can't be dropped because it is in upstream pytorch.

/ok to test 6eee42f

That's a fair point. The The actual runtime behavior is still made clear by the

@mkhona-nvidia this actually turned out to be a very mergeable, limited-scope version; please take a quick look.
Codecov Report: ❌ Patch coverage is
Signed-off-by: Hao Wu <skyw@nvidia.com>
Contain the experimental tensor-parallel REKLS code in its own module: move the TpRekls class and the all_gather_grad_and_kronecker_factors_tp helper out of rekls.py / tp_utils.py into tp_rekls.py, and update tests.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>

/ok to test 2b0d10a
Signed-off-by: Hao Wu <skyw@nvidia.com>

/ok to test 761bf96
For experimental purposes only; DO NOT merge yet.
UPDATE: The experiments turned out to be very successful. However, the double eigh is very slow for production use, so merging is deferred until at least the next release.