
implement ring attention with zigzag layout #39

Merged

MasterJH5574 merged 1 commit into mlc-ai:main from haok1402:0513-zigzag-attention on May 14, 2026


Conversation

@haok1402 (Collaborator)

Changes:

  1. use a zigzag layout to ensure causal load balance under context parallelism (see the sketch after this list)
  2. overlap flash attention compute on the current block with the send-recv of the other KV blocks
  3. apply torch.compile to the combination of lse and out
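
A minimal sketch of the zigzag sharding referenced in item 1, assuming a context-parallel group of size `cp_size`; the helper names (`zigzag_shard`, `zigzag_position_ids`) and the `[B, S, ...]` layout are illustrative, not the names used in this PR:

```python
import torch

def zigzag_shard(x: torch.Tensor, cp_rank: int, cp_size: int) -> torch.Tensor:
    """Split the sequence dim into 2 * cp_size chunks and keep chunks cp_rank and
    (2 * cp_size - 1 - cp_rank), so every rank holds a mix of early (cheap) and
    late (expensive) causal positions."""
    assert x.size(1) % (2 * cp_size) == 0
    chunks = x.chunk(2 * cp_size, dim=1)  # x: [B, S, ...]
    return torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]], dim=1)

def zigzag_position_ids(seq_len: int, cp_rank: int, cp_size: int) -> torch.Tensor:
    """Global positions of the local shard, so rotary embeddings are applied with
    the original (pre-shard) positions."""
    c = seq_len // (2 * cp_size)
    lo = torch.arange(cp_rank * c, (cp_rank + 1) * c)
    hi = torch.arange((2 * cp_size - 1 - cp_rank) * c, (2 * cp_size - cp_rank) * c)
    return torch.cat([lo, hi])

# With cp_size = 4 the chunk assignment is rank 0 -> (0, 7), rank 1 -> (1, 6),
# rank 2 -> (2, 5), rank 3 -> (3, 4), so causal work per rank is roughly equal,
# whereas a contiguous split makes the last rank attend to the entire prefix.
```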

Qwen3-30b-a3b, cp4-s32k, one layer:

  • Total time: 13.87 ms → 8.30 ms (1.67× faster)
  • Forward: 4.14 ms → 2.47 ms (1.68×)
  • Backward: 9.72 ms → 5.83 ms (1.67×)

On the main branch, the workload with naive ring attention is imbalanced. On earlier ranks, the NCCL stream takes 83.1% of the time, waiting for the last rank to complete; the last rank spends 95.2% of its time on the default stream (i.e. compute), because it can attend to all previous positions.

[profiler timeline: main branch]

On the current branch, the split between the NCCL stream and the default stream is more even, with the transfer of KV blocks happening while the FA4 kernels run.

[profiler timeline: this branch]
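
A rough sketch of this overlap, with `attn_fn` (e.g. a flash-attention step returning `(out, lse)`) and `combine_fn` supplied by the caller because they are not this repo's actual API; the async P2P ops for the next KV block are posted before the attention kernel for the current block is launched:

```python
import torch
import torch.distributed as dist

def ring_attention_fwd(q, k, v, attn_fn, combine_fn):
    """One-query-block ring attention forward pass (illustration only).
    attn_fn(q, k, v) -> (out, lse); combine_fn merges the per-step partials via
    their lse. Per-step causal masking and the dedicated CP process group are
    omitted for brevity."""
    rank, world = dist.get_rank(), dist.get_world_size()
    send_to, recv_from = (rank + 1) % world, (rank - 1) % world

    partials, reqs = [], []
    for step in range(world):
        if step + 1 < world:
            # Post async send/recv of the next KV block first; NCCL executes these
            # on its own stream while the attention kernel below runs on the
            # default (compute) stream.
            next_k, next_v = torch.empty_like(k), torch.empty_like(v)
            reqs = dist.batch_isend_irecv([
                dist.P2POp(dist.isend, k, send_to),
                dist.P2POp(dist.isend, v, send_to),
                dist.P2POp(dist.irecv, next_k, recv_from),
                dist.P2POp(dist.irecv, next_v, recv_from),
            ])
        partials.append(attn_fn(q, k, v))  # compute on the KV block we already hold
        for r in reqs:
            r.wait()
        if step + 1 < world:
            k, v = next_k, next_v
        reqs = []
    return combine_fn(partials)  # lse-weighted merge of the per-step outputs
```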


@gemini-code-assist (Bot) left a comment


Code Review

This pull request implements zigzag ring attention for context parallelism, which balances the causal workload across ranks by sharding the global sequence into non-contiguous chunks. The changes include updates to the model architectures (DeepSeek and Qwen) to handle the zigzag layout for rotary embeddings, modifications to the training setup for sequence length validation, and a complete rewrite of the ring attention operator to support asynchronous P2P communication and zigzag chunking. Feedback focuses on performance optimizations, specifically regarding the communication bandwidth of gradients in the backward pass and the redundant reconstruction of position IDs during the forward pass.

Comment thread pithtrain/operators/ring_attention.py
Comment thread pithtrain/models/deepseek_v2_lite.py
Comment thread pithtrain/models/qwen3_30b_a3b.py
@haok1402 force-pushed the 0513-zigzag-attention branch from b19fa1c to 73d4d95 on May 13, 2026 at 22:38


Commented lines in pithtrain/operators/ring_attention.py:

```
    out: [B, S, H, D], lse: [B, H, S] (float32).
@torch.compile(fullgraph=True)
```
Member


Do we need the decorator here? Isn't the ring attention always run under an outside @torch.compile decorator?

Member


Maybe we need to quickly check which kernel is used here without the decorator.

Collaborator Author


> Isn't the ring attention always run under an outside @torch.compile

No. We currently unwrap the torch.compile if ring attention is used.

```python
if self.self_attn.use_ring_attn:
    # Drop the @torch.compile wrapper and rebind the plain function as a bound method.
    self._forward_attn_compute = self._forward_attn_compute.__wrapped__.__get__(
        self, type(self)
    )
```

We apply the torch.compile(fullgraph=True) decorator, mainly to fuse the kernels in combine_partial.

[kernel trace screenshots]

We do see a slight reduction in latency:

  • before this compile: fwd 3.004 ms, bwd 6.174 ms, step 9.178 ms
  • after: fwd 2.467 ms, bwd 5.833 ms, step 8.300 ms
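
For context, the lse/out merge that benefits from this fusion could look roughly like the following. This is a sketch only, assuming the `out: [B, S, H, D]` / `lse: [B, H, S]` shapes quoted above; the PR's actual `combine_partial` may differ. Without the decorator these elementwise ops run eagerly as separate kernel launches; `torch.compile(fullgraph=True)` lets Inductor fuse them.

```python
import torch

@torch.compile(fullgraph=True)
def combine_partial(out_a, lse_a, out_b, lse_b):
    """Merge two partial softmax-attention results using their log-sum-exps.
    out_*: [B, S, H, D], lse_*: [B, H, S] (float32). Hypothetical sketch, not
    necessarily the operator's real implementation."""
    lse = torch.logaddexp(lse_a, lse_b)                          # combined lse, [B, H, S]
    w_a = torch.exp(lse_a - lse).transpose(1, 2).unsqueeze(-1)   # [B, S, H, 1]
    w_b = torch.exp(lse_b - lse).transpose(1, 2).unsqueeze(-1)
    out = w_a * out_a.float() + w_b * out_b.float()
    return out.to(out_a.dtype), lse
```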

Collaborator Author


> Isn't the ring attention always run under an outside @torch.compile


@MasterJH5574 I think torch.compile should overall be applicable with ring attention. We should have a follow-up PR to address this.

Comment thread pithtrain/operators/ring_attention.py Outdated
Comment thread pithtrain/operators/ring_attention.py Outdated
@haok1402 force-pushed the 0513-zigzag-attention branch from 73d4d95 to 9bbf06c on May 14, 2026 at 00:31
@MasterJH5574 (Member) left a comment


Thanks @haok1402!

@MasterJH5574 merged commit 597e35c into mlc-ai:main on May 14, 2026
1 check passed
