implement ring attention with zigzag layout #39
Conversation
Code Review
This pull request implements zigzag ring attention for context parallelism, which balances the causal workload across ranks by sharding the global sequence into non-contiguous chunks. The changes include updates to the model architectures (DeepSeek and Qwen) to handle the zigzag layout for rotary embeddings, modifications to the training setup for sequence length validation, and a complete rewrite of the ring attention operator to support asynchronous P2P communication and zigzag chunking. Feedback focuses on performance optimizations, specifically regarding the communication bandwidth of gradients in the backward pass and the redundant reconstruction of position IDs during the forward pass.
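For orientation, here is a minimal sketch of the zigzag sharding the summary refers to. The helper name zigzag_shard and its signature are illustrative, not the PR's actual API; it assumes the common layout where the global sequence is split into 2 * world_size chunks and rank r keeps chunks r and 2 * world_size - 1 - r.

```python
import torch

def zigzag_shard(x: torch.Tensor, rank: int, world_size: int, dim: int = 1) -> torch.Tensor:
    """Shard the global sequence along `dim` with a zigzag layout (illustrative).

    The sequence is split into 2 * world_size equal chunks; rank r keeps
    chunks r and (2 * world_size - 1 - r). Each rank therefore owns one
    "early" and one "late" chunk, which balances the causal attention
    workload across ranks.
    """
    chunks = x.chunk(2 * world_size, dim=dim)
    return torch.cat([chunks[rank], chunks[2 * world_size - 1 - rank]], dim=dim)

# Example with seq_len 8 and world_size 2:
#   rank 0 holds global positions [0, 1, 6, 7], rank 1 holds [2, 3, 4, 5].
```

Because each rank's positions are non-contiguous, the rotary-embedding position IDs have to be rebuilt per rank, which is why the model architecture changes mentioned above are needed.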
b19fa1c to 73d4d95
out: [B, S, H, D], lse: [B, H, S] (float32).
@torch.compile(fullgraph=True)
Do we need the decorator here? Isn't the ring attention always run under an outside @torch.compile decorator?
We may need to quickly check which kernel is used here without the decorator.
> Isn't the ring attention always run under an outside @torch.compile?

No. We currently unwrap the torch.compile if ring attention is used.

Pith-Train/pithtrain/models/qwen3_30b_a3b.py, lines 486 to 489 in c76a591
We apply the @torch.compile(fullgraph=True) decorator mainly to fuse the kernels in combine_partial. We do see a slight reduction in latency (a sketch of the combine is shown after these numbers):
- before this compile: fwd 3.004 ms, bwd 6.174 ms, step 9.178 ms
- after: fwd 2.467 ms, bwd 5.833 ms, step 8.300 ms
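A minimal sketch of the kind of elementwise chain such a decorator can fuse, using the [B, S, H, D] / [B, H, S] shapes from the diff above; the PR's actual combine_partial may be organized differently.

```python
import torch

@torch.compile(fullgraph=True)
def combine_partial(out_a: torch.Tensor, lse_a: torch.Tensor,
                    out_b: torch.Tensor, lse_b: torch.Tensor):
    """Merge two partial attention results via their log-sum-exp (sketch).

    out_*: [B, S, H, D]; lse_*: [B, H, S] in float32. The exp/mul/add chain
    below is what fullgraph compilation can fuse into a single kernel.
    """
    lse = torch.logaddexp(lse_a, lse_b)                          # [B, H, S]
    w_a = torch.exp(lse_a - lse).transpose(1, 2).unsqueeze(-1)   # [B, S, H, 1]
    w_b = torch.exp(lse_b - lse).transpose(1, 2).unsqueeze(-1)
    out = out_a * w_a.to(out_a.dtype) + out_b * w_b.to(out_b.dtype)
    return out, lse
```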
> Isn't the ring attention always run under an outside @torch.compile?
>
> No. We currently unwrap the torch.compile if ring attention is used.
>
> Pith-Train/pithtrain/models/qwen3_30b_a3b.py, lines 486 to 489 in c76a591
@MasterJH5574 I think torch.compile should generally be compatible with ring attention. We should have a follow-up PR to address this.
73d4d95 to 9bbf06c
Changes:
Qwen3-30b-a3b, cp4-s32k, one layer:
On the main branch, the workload with naive ring attention is imbalanced: on earlier ranks, the NCCL stream takes 83.1% of the time waiting for the last rank to complete, while the last rank spends 95.2% of its time on the default stream (i.e. compute), because it attends to all previous positions.
On the current branch, the split between NCCL stream and the default stream is more even, with the transfer of KV blocks happening while the FA4 kernels run.
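For reference, a rough back-of-the-envelope illustration of why a contiguous split skews the causal workload toward later ranks while the zigzag split equalizes it. The numbers are illustrative only, not measurements from this PR.

```python
# Count attended (query, key) pairs per rank under a causal mask, for a
# contiguous split vs. a zigzag split. Illustrative only, not PR code.
def causal_work(per_rank_positions):
    # a query at global position p attends to p + 1 keys
    return [sum(p + 1 for p in positions) for positions in per_rank_positions]

def contiguous_layout(seq_len, world_size):
    chunk = seq_len // world_size
    return [list(range(r * chunk, (r + 1) * chunk)) for r in range(world_size)]

def zigzag_layout(seq_len, world_size):
    chunk = seq_len // (2 * world_size)
    return [
        list(range(r * chunk, (r + 1) * chunk))
        + list(range((2 * world_size - 1 - r) * chunk, (2 * world_size - r) * chunk))
        for r in range(world_size)
    ]

S, W = 32, 4
print(causal_work(contiguous_layout(S, W)))  # [36, 100, 164, 228] -> last rank ~6x the first
print(causal_work(zigzag_layout(S, W)))      # [132, 132, 132, 132] -> balanced
```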