From 9bbf06c0d802d7fba57c3dc9c5bce1b0e24012a5 Mon Sep 17 00:00:00 2001 From: haok1402 Date: Wed, 13 May 2026 18:19:37 -0400 Subject: [PATCH] implement ring attetion with zigzag layout --- pithtrain/models/deepseek_v2_lite.py | 23 +- pithtrain/models/qwen3_30b_a3b.py | 22 +- pithtrain/modules/training.py | 9 +- pithtrain/operators/ring_attention.py | 454 +++++++++++++-------- pithtrain/tasks/pretrain_language_model.py | 32 +- tests/operators/test_ring_attention.py | 35 +- 6 files changed, 378 insertions(+), 197 deletions(-) diff --git a/pithtrain/models/deepseek_v2_lite.py b/pithtrain/models/deepseek_v2_lite.py index 541fbd2..1a7e5c6 100644 --- a/pithtrain/models/deepseek_v2_lite.py +++ b/pithtrain/models/deepseek_v2_lite.py @@ -408,7 +408,7 @@ def forward( query_states, key_states, value_states.contiguous(), - softmax_scale=self.softmax_scale, + sm_scale=self.softmax_scale, cp_group=self.cp_group, ) else: @@ -646,6 +646,7 @@ def __init__( self.stage_id = stage_id self.cp_group = cp_group self.cp_rank = cp_group.rank() if cp_group is not None else 0 + self.cp_size = cp_group.size() if cp_group is not None else 1 self.embed_tokens = ( nn.Embedding(config.vocab_size, config.hidden_size) if stage_id == 0 else None ) @@ -714,11 +715,23 @@ def forward( hidden_states = self.embed_tokens(hidden_states) seq_len = hidden_states.shape[1] - offset = self.cp_rank * seq_len - cos, sin = self.rotary_emb(hidden_states, seq_len=offset + seq_len) + # Zigzag CP layout: the local seq_len tokens come from two non-contiguous + # global chunks. Build the global position IDs by concatenating the + # front block and the mirror back block, then gather cos/sin by position. + block = seq_len // 2 + global_seq_len = seq_len * self.cp_size + front_start = self.cp_rank * block + back_start = (2 * self.cp_size - self.cp_rank - 1) * block + position_ids = torch.cat( + [ + torch.arange(front_start, front_start + block, device=hidden_states.device), + torch.arange(back_start, back_start + block, device=hidden_states.device), + ] + ) + cos, sin = self.rotary_emb(hidden_states, seq_len=global_seq_len) position_embeddings = ( - cos[offset : offset + seq_len].unsqueeze(0).to(dtype=hidden_states.dtype), - sin[offset : offset + seq_len].unsqueeze(0).to(dtype=hidden_states.dtype), + cos[position_ids].unsqueeze(0).to(dtype=hidden_states.dtype), + sin[position_ids].unsqueeze(0).to(dtype=hidden_states.dtype), ) for _, layer in self.layers.items(): layer._position_embeddings = position_embeddings diff --git a/pithtrain/models/qwen3_30b_a3b.py b/pithtrain/models/qwen3_30b_a3b.py index ba51059..75d9b23 100644 --- a/pithtrain/models/qwen3_30b_a3b.py +++ b/pithtrain/models/qwen3_30b_a3b.py @@ -406,7 +406,7 @@ def forward( query_states, key_states, value_states, - softmax_scale=self.scaling, + sm_scale=self.scaling, cp_group=self.cp_group, ) @@ -673,6 +673,7 @@ def __init__( self.num_stages = num_stages self.cp_group = cp_group self.cp_rank = cp_group.rank() if cp_group is not None else 0 + self.cp_size = cp_group.size() if cp_group is not None else 1 hidden_size = config.hidden_size num_attention_heads = config.num_attention_heads @@ -765,12 +766,21 @@ def forward( bsz, seq_len, _ = hidden_states.shape - offset = self.cp_rank * seq_len - cos, sin = self.rotary_emb(hidden_states, seq_len=offset + seq_len) - position_embeddings = ( - cos[offset : offset + seq_len].unsqueeze(0), - sin[offset : offset + seq_len].unsqueeze(0), + # Zigzag CP layout: the local seq_len tokens come from two non-contiguous + # global chunks. Build the global position IDs by concatenating the + # front block and the mirror back block, then gather cos/sin by position. + block = seq_len // 2 + global_seq_len = seq_len * self.cp_size + front_start = self.cp_rank * block + back_start = (2 * self.cp_size - self.cp_rank - 1) * block + position_ids = torch.cat( + [ + torch.arange(front_start, front_start + block, device=hidden_states.device), + torch.arange(back_start, back_start + block, device=hidden_states.device), + ] ) + cos, sin = self.rotary_emb(hidden_states, seq_len=global_seq_len) + position_embeddings = (cos[position_ids].unsqueeze(0), sin[position_ids].unsqueeze(0)) for layer_idx_str, layer in self.layers.items(): layer._position_embeddings = position_embeddings diff --git a/pithtrain/modules/training.py b/pithtrain/modules/training.py index 7680ba4..8448fd1 100644 --- a/pithtrain/modules/training.py +++ b/pithtrain/modules/training.py @@ -316,9 +316,12 @@ def setup_model(cfg: TrainingCfg, ctx: TrainingCtx, distributed: DistributedCtx) module_config.ep_size = ep_size assert hasattr(module_config, "hidden_size") assert isinstance(module_config.hidden_size, int) - assert cfg.sequence_length % cp_size == 0, ( - f"sequence_length ({cfg.sequence_length}) must be divisible by context_parallel_size ({cp_size})" - ) + if cfg.sequence_length % (2 * cp_size) != 0: + raise ValueError( + f"sequence_length ({cfg.sequence_length}) must be divisible by " + f"2 * context_parallel_size ({2 * cp_size}); zigzag ring attention " + f"splits the sequence into 2*cp_size equal chunks" + ) hidden_size = module_config.hidden_size diff --git a/pithtrain/operators/ring_attention.py b/pithtrain/operators/ring_attention.py index 2fb58db..47c9916 100644 --- a/pithtrain/operators/ring_attention.py +++ b/pithtrain/operators/ring_attention.py @@ -1,191 +1,327 @@ """ -Ring attention for context parallelism. - -Splits the sequence across CP ranks. Each rank holds Q for its local chunk -and passes K/V around a ring so every rank computes full causal attention. -Partial outputs are combined with online softmax (log-sum-exp rescaling). - -Forward: KV travels next->prev (send to rank+1, recv from rank-1). -Backward: re-uses saved KV chunks and passes the *combined* out/lse to - flash_attn_backward so it reconstructs the correct global attention - weights. dK/dV contributions are redistributed to originating - ranks via distance-based P2P exchange. - -Known limitations ------------------ -* **Causal load imbalance**: With contiguous chunking, rank 0 processes 1 KV - chunk while rank C-1 processes C chunks. Zigzag/striped partitioning would - fix this but is not yet implemented. -* **O(cp_size) KV memory in backward**: All received KV chunks are saved from - the forward pass for reuse in the backward. For large cp_size a ring-based - KV recomputation scheme would reduce this to O(1). +Zigzag ring attention for context parallelism. + +Causal flash attention sharded across cp_size ranks. Two design choices: zigzag chunking for +load balance, and async ring P2P for compute/comm overlap. + +The global sequence is split into 2 * cp_size equal chunks of size block = S / (2 * cp_size). +Rank r holds two of them, chunk r (front block) and chunk 2 * cp_size - r - 1 (back block). +The chunk-to-rank assignment is mirrored; for cp_size = 4 it looks like + + chunk: 0 1 2 3 4 5 6 7 + rank: 0 1 2 3 3 2 1 0 + +Lower-indexed ranks pair one very-early chunk with one very-late chunk; higher-indexed ranks +pair two near-middle chunks. The early chunk has few tokens before it (light causal work), +the late one has many (heavy). They cancel, so every rank ends up doing the same amount of +attention work. + +Q stays on its home rank; K/V rotate one hop per step in the +1 direction. Each step posts +its next-step batch_isend_irecv before launching its flash kernel, so the transfer overlaps +with compute. The backward runs two rings concurrently: K/V rotates as in the forward (we +re-derive partial outputs rather than save cp_size copies), and partial dK/dV rotates in the +same direction so every contribution reaches its originating rank after cp_size hops. + +Q, K, V, and the returned output are all in the zigzag local layout; the caller (data loader +and RoPE) does the permutation. + +The non-obvious part is which flash call covers each step. At step s, rank r holds K/V +originating from kv = (r - s) mod cp_size; the rotated K has its own front block (chunk kv) +and back block (chunk 2 * cp_size - kv - 1). Comparing global chunk positions, exactly one +of three pictures holds: + + step == 0 (kv == r) flash(q, k, v, causal=True) + 1 <= s <= r (kv < r) flash(q, k[:, :block], v[:, :block], causal=False) + s > r (kv > r) flash(q[:, block:], k, v, causal=False) + + step 0: K is local. The four (q_part, k_part) sub-blocks line up with a 2*block local + causal mask, so one causal call handles everything. + + kv < r: K came from a lower-indexed rank, whose chunks live at extreme positions. Only its + front block survives the global causal mask; both halves of Q attend to it fully. + + kv > r: K came from a higher-indexed rank, whose chunks live at central positions. Only + Q's back block attends, sees the full rotated K, and only the back-block positions + of out / lse are updated. + +Every step costs the same: one causal pass on length 2*block, or one non-causal pass on a +2*block-by-block rectangle. """ +from typing import List, Optional, Tuple + import torch -import torch.distributed as dist +import torch.nn.functional as F from flash_attn.cute.interface import _flash_attn_bwd, _flash_attn_fwd +from torch.distributed import ( + P2POp, + ProcessGroup, + Work, + batch_isend_irecv, + get_global_rank, + get_rank, + get_world_size, + irecv, + isend, +) -def _ring_send_recv_kv( - send_k: torch.Tensor, - send_v: torch.Tensor, - next_global: int, - prev_global: int, -): - assert send_k.is_contiguous() and send_v.is_contiguous(), ( - "ring attention send_k or send_v is not contiguous" - ) - recv_k = torch.empty_like(send_k) - recv_v = torch.empty_like(send_v) - ops = [ - dist.P2POp(dist.isend, send_k, next_global), - dist.P2POp(dist.isend, send_v, next_global), - dist.P2POp(dist.irecv, recv_k, prev_global), - dist.P2POp(dist.irecv, recv_v, prev_global), - ] - reqs = dist.batch_isend_irecv(ops) - # NCCL runs on its own internal CUDA stream - for req in reqs: - req.wait() - return recv_k, recv_v +def post_ring_kv( + k: torch.Tensor, + v: torch.Tensor, + cp_group: ProcessGroup, + dst: int, + src: int, + k_recv: Optional[torch.Tensor] = None, + v_recv: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, List[Work]]: + """ + Async (K, V) ring hop. Pre-allocated recv buffers let the backward recycle just-sent + dK/dV buffers as the next iteration's recv slots. + """ + if not (k.is_contiguous() and v.is_contiguous()): + raise ValueError("ring P2P requires contiguous send buffers") + if k_recv is None: + k_recv = torch.empty_like(k) + if v_recv is None: + v_recv = torch.empty_like(v) + ops = [] + ops.append(P2POp(isend, k, dst, group=cp_group)) + ops.append(P2POp(isend, v, dst, group=cp_group)) + ops.append(P2POp(irecv, k_recv, src, group=cp_group)) + ops.append(P2POp(irecv, v_recv, src, group=cp_group)) + work = batch_isend_irecv(ops) + return k_recv, v_recv, work + +def wait_ring(work: List[Work]) -> None: + for req in work: + req.wait() -def _online_softmax_combine(out1, lse1, out2, lse2): - """Combine two partial attention outputs via online softmax rescaling. - out: [B, S, H, D], lse: [B, H, S] (float32). +@torch.compile(fullgraph=True) +def combine_partial( + out: Optional[torch.Tensor], + lse: Optional[torch.Tensor], + partial_out: torch.Tensor, + partial_lse: torch.Tensor, + start: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Online-softmax merge of a partial flash output into the running fp32 accumulator. When + start > 0 only positions [start:] are updated; the first call must have start == 0. """ - max_lse = torch.maximum(lse1, lse2) - exp1 = torch.exp(lse1 - max_lse) - exp2 = torch.exp(lse2 - max_lse) - # [B, H, S] -> [B, S, H, 1] for broadcast with [B, S, H, D] - e1 = exp1.transpose(1, 2).unsqueeze(-1) - e2 = exp2.transpose(1, 2).unsqueeze(-1) - new_out = (e1 * out1 + e2 * out2) / (e1 + e2) - new_lse = max_lse + torch.log(exp1 + exp2) - return new_out, new_lse + partial_out = partial_out.to(torch.float32) + partial_lse = partial_lse.transpose(-2, -1).unsqueeze(-1) + if out is None: + if start != 0: + raise ValueError("first combine_partial call must update the full sequence") + return partial_out, partial_lse + if start == 0: + weight = torch.sigmoid(partial_lse - lse) + new_out = out + weight * (partial_out - out) + new_lse = lse + F.softplus(partial_lse - lse) + return new_out, new_lse + cur_out = out[:, start:] + cur_lse = lse[:, start:] + weight = torch.sigmoid(partial_lse - cur_lse) + out[:, start:] = cur_out + weight * (partial_out - cur_out) + lse[:, start:] = cur_lse + F.softplus(partial_lse - cur_lse) + return out, lse -class RingAttentionFunc(torch.autograd.Function): - """Causal ring attention with flash-attn kernels.""" +def zigzag_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sm_scale: float, + cp_group: ProcessGroup, +) -> Tuple[torch.Tensor, torch.Tensor]: + cp_rank, cp_size = get_rank(cp_group), get_world_size(cp_group) + dst = get_global_rank(cp_group, (cp_rank + 1) % cp_size) + src = get_global_rank(cp_group, (cp_rank - 1) % cp_size) + block = q.shape[1] // 2 + q_back = q[:, block:] - @staticmethod - def forward(ctx, q, k, v, softmax_scale, cp_rank, cp_size, global_ranks): - B, S, H, _ = q.shape - DV = v.shape[-1] - combined_out = torch.zeros((B, S, H, DV), dtype=q.dtype, device=q.device) - combined_lse = torch.full( - (B, H, S), torch.finfo(torch.float32).min, device=q.device, dtype=torch.float32 - ) + out: Optional[torch.Tensor] = None + lse: Optional[torch.Tensor] = None + next_k: Optional[torch.Tensor] = None + next_v: Optional[torch.Tensor] = None + kv_work: Optional[List[Work]] = None - next_global = global_ranks[(cp_rank + 1) % cp_size] - prev_global = global_ranks[(cp_rank - 1) % cp_size] - assert k.is_contiguous() and v.is_contiguous(), "ring attention k or v is not contiguous" - cur_k, cur_v = k, v - saved_k, saved_v, saved_ranks = [], [], [] + for step in range(cp_size): + if step + 1 < cp_size: + next_k, next_v, kv_work = post_ring_kv(k, v, cp_group, dst, src) + if step == 0: + partial_out, partial_lse = _flash_attn_fwd( + q, k, v, softmax_scale=sm_scale, causal=True, return_lse=True + ) + out, lse = combine_partial(out, lse, partial_out, partial_lse) + elif step <= cp_rank: + partial_out, partial_lse = _flash_attn_fwd( + q, + k[:, :block], + v[:, :block], + softmax_scale=sm_scale, + causal=False, + return_lse=True, + ) + out, lse = combine_partial(out, lse, partial_out, partial_lse) + else: + partial_out, partial_lse = _flash_attn_fwd( + q_back, k, v, softmax_scale=sm_scale, causal=False, return_lse=True + ) + out, lse = combine_partial(out, lse, partial_out, partial_lse, start=block) + if step + 1 < cp_size: + wait_ring(kv_work) + k, v = next_k, next_v - for step in range(cp_size): - kv_rank = (cp_rank - step) % cp_size + out = out.to(q.dtype) + lse = lse.squeeze(-1).transpose(1, 2).contiguous() + return out, lse - if kv_rank <= cp_rank: - use_causal = kv_rank == cp_rank - step_out, step_lse = _flash_attn_fwd( - q, cur_k, cur_v, softmax_scale=softmax_scale, causal=use_causal, return_lse=True - ) - combined_out, combined_lse = _online_softmax_combine( - combined_out, combined_lse, step_out, step_lse - ) - saved_k.append(cur_k) - saved_v.append(cur_v) - saved_ranks.append(kv_rank) - if step < cp_size - 1: - cur_k, cur_v = _ring_send_recv_kv(cur_k, cur_v, next_global, prev_global) +def zigzag_backward( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + sm_scale: float, + cp_group: ProcessGroup, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + cp_rank, cp_size = get_rank(cp_group), get_world_size(cp_group) + dst = get_global_rank(cp_group, (cp_rank + 1) % cp_size) + src = get_global_rank(cp_group, (cp_rank - 1) % cp_size) + block = q.shape[1] // 2 - combined_out = combined_out.to(q.dtype) + dout_back = dout[:, block:].contiguous() + q_back = q[:, block:].contiguous() + out_back = out[:, block:].contiguous() + lse_back = lse[:, :, block:].contiguous() - assert combined_lse.is_contiguous() - ctx.save_for_backward(q, combined_out, combined_lse, *saved_k, *saved_v) - ctx.softmax_scale = softmax_scale - ctx.cp_rank = cp_rank - ctx.cp_size = cp_size - ctx.global_ranks = global_ranks - ctx.saved_ranks = saved_ranks - ctx.n_saved = len(saved_ranks) - return combined_out + dq: Optional[torch.Tensor] = None + dk: Optional[torch.Tensor] = None + dv: Optional[torch.Tensor] = None + next_k: Optional[torch.Tensor] = None + next_v: Optional[torch.Tensor] = None + kv_work: Optional[List[Work]] = None + incoming_dk: Optional[torch.Tensor] = None + incoming_dv: Optional[torch.Tensor] = None + grad_recv_slot_k: Optional[torch.Tensor] = None + grad_recv_slot_v: Optional[torch.Tensor] = None + grad_work: Optional[List[Work]] = None - @staticmethod - def backward(ctx, dout): - n = ctx.n_saved - saved = ctx.saved_tensors - q, combined_out, combined_lse = saved[0], saved[1], saved[2] - all_k = list(saved[3 : 3 + n]) - all_v = list(saved[3 + n : 3 + 2 * n]) - cp_rank, cp_size, global_ranks = ctx.cp_rank, ctx.cp_size, ctx.global_ranks - - dq = torch.zeros_like(q) - local_dk = torch.zeros_like(all_k[0]) - local_dv = torch.zeros_like(all_v[0]) - remote_dk, remote_dv = {}, {} - - for i, kv_rank in enumerate(ctx.saved_ranks): - dq_s, dk_s, dv_s = _flash_attn_bwd( - q, - all_k[i], - all_v[i], - combined_out, - dout, - combined_lse, - softmax_scale=ctx.softmax_scale, - causal=(kv_rank == cp_rank), + for step in range(cp_size): + if step + 1 < cp_size: + next_k, next_v, kv_work = post_ring_kv(k, v, cp_group, dst, src) + if step == 0: + dq_step, dk_step, dv_step = _flash_attn_bwd( + q, k, v, out, dout, lse, softmax_scale=sm_scale, causal=True ) - dq += dq_s - if kv_rank == cp_rank: - local_dk += dk_s - local_dv += dv_s + dq = dq_step.to(torch.float32) + dk = dk_step.to(torch.float32) + dv = dv_step.to(torch.float32) + else: + if step <= cp_rank: + dq_step, dk_step, dv_step = _flash_attn_bwd( + q, + k[:, :block], + v[:, :block], + out, + dout, + lse, + softmax_scale=sm_scale, + causal=False, + ) + dq += dq_step else: - remote_dk[kv_rank] = dk_s - remote_dv[kv_rank] = dv_s - - # Redistribute dK/dV to originating ranks via P2P. - # At distance d: rank r sends dK for rank (r-d) and receives dK from rank (r+d). - for d in range(1, cp_size): - target, source = cp_rank - d, cp_rank + d - ops, recv_dk_buf, recv_dv_buf = [], None, None - - if target >= 0 and target in remote_dk: - dst = global_ranks[target] - assert remote_dk[target].is_contiguous() and remote_dv[target].is_contiguous(), ( - "ring attention remote_dk or remote_dv is not contiguous" + dq_step, dk_step, dv_step = _flash_attn_bwd( + q_back, + k, + v, + out_back, + dout_back, + lse_back, + softmax_scale=sm_scale, + causal=False, ) - ops.append(dist.P2POp(dist.isend, remote_dk[target], dst)) - ops.append(dist.P2POp(dist.isend, remote_dv[target], dst)) - if source < cp_size: - src = global_ranks[source] - recv_dk_buf = torch.empty_like(local_dk) - recv_dv_buf = torch.empty_like(local_dv) - ops.append(dist.P2POp(dist.irecv, recv_dk_buf, src)) - ops.append(dist.P2POp(dist.irecv, recv_dv_buf, src)) + dq[:, block:] += dq_step + # Adopt the previous hop's dK/dV as the working accumulator. + # The buffers we shipped one hop ago are free to recycle. + wait_ring(grad_work) + grad_recv_slot_k, grad_recv_slot_v = dk, dv + dk, dv = incoming_dk, incoming_dv + if step <= cp_rank: + dk[:, :block] += dk_step + dv[:, :block] += dv_step + else: + dk += dk_step + dv += dv_step + if step + 1 < cp_size: + wait_ring(kv_work) + k, v = next_k, next_v + incoming_dk, incoming_dv, grad_work = post_ring_kv( + dk, dv, cp_group, dst, src, grad_recv_slot_k, grad_recv_slot_v + ) - if ops: - for req in dist.batch_isend_irecv(ops): - req.wait() - if recv_dk_buf is not None: - local_dk += recv_dk_buf - local_dv += recv_dv_buf + wait_ring(grad_work) + return dq.to(q.dtype), incoming_dk.to(q.dtype), incoming_dv.to(q.dtype) + + +class ZigzagRingAttention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, sm_scale, cp_group): + if not (k.is_contiguous() and v.is_contiguous()): + raise ValueError("ring attention requires contiguous k and v") + if q.shape[1] % 2: + raise ValueError(f"zigzag layout needs even local seq len, got {q.shape[1]}") + out, lse = zigzag_forward(q, k, v, sm_scale, cp_group) + ctx.save_for_backward(q, k, v, out, lse) + ctx.sm_scale = sm_scale + ctx.cp_group = cp_group + return out - return dq, local_dk, local_dv, None, None, None, None + @staticmethod + def backward(ctx, dout): + q, k, v, out, lse = ctx.saved_tensors + dq, dk, dv = zigzag_backward(dout, q, k, v, out, lse, ctx.sm_scale, ctx.cp_group) + return dq, dk, dv, None, None def ring_attention_func( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - softmax_scale: float, - cp_group: dist.ProcessGroup, -): - """Causal ring attention across context-parallel ranks.""" - cp_rank = cp_group.rank() - cp_size = cp_group.size() - global_ranks = [dist.distributed_c10d.get_global_rank(cp_group, r) for r in range(cp_size)] - return RingAttentionFunc.apply(q, k, v, softmax_scale, cp_rank, cp_size, global_ranks) + sm_scale: float, + cp_group: ProcessGroup, +) -> torch.Tensor: + """ + Causal zigzag ring attention across context-parallel ranks. + + Parameters + ---------- + q : torch.Tensor + Query tensor of shape [batch, S_local, num_q_heads, head_dim] in zigzag local + layout, where S_local = S / cp_size for a global sequence of length S. + k : torch.Tensor + Key tensor of shape [batch, S_local, num_kv_heads, head_dim] in the same zigzag + layout as q. Must be contiguous; rotated around the ring during forward and + re-rotated during backward. + v : torch.Tensor + Value tensor of shape [batch, S_local, num_kv_heads, head_dim_v] in the same + zigzag layout as q. Must be contiguous. + sm_scale : float + Softmax scale, typically head_dim ** -0.5. + cp_group : torch.distributed.ProcessGroup + Context-parallel process group. Must contain at least two ranks; the single-rank + case is handled upstream by skipping the ring entirely. + + Returns + ------- + torch.Tensor + Attention output of shape [batch, S_local, num_q_heads, head_dim_v] in q.dtype, + returned in the same zigzag local layout as q. + """ + return ZigzagRingAttention.apply(q, k, v, sm_scale, cp_group) diff --git a/pithtrain/tasks/pretrain_language_model.py b/pithtrain/tasks/pretrain_language_model.py index 59da1cb..c6bedd9 100644 --- a/pithtrain/tasks/pretrain_language_model.py +++ b/pithtrain/tasks/pretrain_language_model.py @@ -90,27 +90,35 @@ def get_global_batch( local_batch_size = global_batch_size // (dp_size * ep_size) start0 = step * global_batch_size + (dp_rank * ep_size + ep_rank) * micro_batch_size - # Compute the CP sub-range so we only read the needed tokens from mmap. + # Zigzag context-parallel sharding. We split the global sequence into + # 2*cp_size equal chunks and assign rank r the pair (r, 2*cp_size-r-1). + # The local sequence is the concatenation of the "front" chunk and the + # mirror "back" chunk, balancing the causal workload across CP ranks + # (see pithtrain/operators/ring_attention.py for the matching attention + # implementation). For cp_size == 1 this reduces to a contiguous read. cp_size = ctx.distributed.cp_size - if cp_size > 1: - cp_rank = ctx.distributed.cp_rank - local_seq_len = sequence_length // cp_size - seq_offset = cp_rank * local_seq_len - else: - local_seq_len = sequence_length - seq_offset = 0 + cp_rank = ctx.distributed.cp_rank + block = sequence_length // (2 * cp_size) + local_seq_len = 2 * block + front_offset = cp_rank * block + back_offset = (2 * cp_size - cp_rank - 1) * block # single allocation on host, then one HtoD transfer per tensor local_tokens = torch.empty((local_batch_size, local_seq_len), dtype=torch.long) local_labels = torch.empty((local_batch_size, local_seq_len), dtype=torch.long) - # fill in one pass: k iterates over our rank-local batch rows + # fill in one pass: k iterates over our rank-local batch rows. Each sample + # is two memmap reads (front block + back block) followed by an in-place + # concat into the pre-allocated host buffer. for k in range(local_batch_size): acc, off = divmod(k, micro_batch_size) index = start0 + acc * effective_batch_size + off - # get_chunk reads smaller chunk of the sequence if cp_size > 1 - tokens, labels = dataset.get_chunk(index, seq_offset, local_seq_len) - local_tokens[k], local_labels[k] = tokens, labels + tokens_a, labels_a = dataset.get_chunk(index, front_offset, block) + tokens_b, labels_b = dataset.get_chunk(index, back_offset, block) + local_tokens[k, :block] = tokens_a + local_tokens[k, block:] = tokens_b + local_labels[k, :block] = labels_a + local_labels[k, block:] = labels_b local_tokens = local_tokens.to(device, non_blocking=True) local_labels = local_labels.to(device, non_blocking=True) diff --git a/tests/operators/test_ring_attention.py b/tests/operators/test_ring_attention.py index 0fe528b..25dc22a 100644 --- a/tests/operators/test_ring_attention.py +++ b/tests/operators/test_ring_attention.py @@ -29,22 +29,29 @@ class Result: dv: torch.Tensor +def extract_zigzag(x: torch.Tensor, cp_rank: int, cp_size: int) -> torch.Tensor: + """ + Extract this rank's zigzag-local slice along the sequence dim (dim=1). + + The global sequence is split into 2*cp_size equal chunks; rank r holds + chunk r (front) concatenated with chunk 2*cp_size - r - 1 (mirror back). + """ + chunks = x.chunk(2 * cp_size, dim=1) + return torch.cat([chunks[cp_rank], chunks[2 * cp_size - cp_rank - 1]], dim=1).contiguous() + + def record(ctx: DistributedCtx, req: Request) -> tuple[Result, Result]: """ Record the forward output and the input gradients dQ, dK, dV for both the baseline and the implementation. The baseline is flash_attn_func run on the - full sequence with no CP communication, effectively the full sequence at - once; the implementation is ring_attention_func run on this rank's sequence - shard with K/V exchanged around the CP ring. + full sequence with no CP communication; the implementation is the zigzag + ring_attention_func run on this rank's zigzag-local slice of Q/K/V with K/V + rotated around the CP ring during the forward and backward passes. """ cp_group = ctx.device_mesh.get_group("cp") cp_rank, cp_size = cp_group.rank(), cp_group.size() device = torch.cuda.current_device() softmax_scale = req.D**-0.5 - S_local = req.S // cp_size - - def shard(x: torch.Tensor) -> torch.Tensor: - return x[:, cp_rank * S_local : (cp_rank + 1) * S_local] torch.manual_seed(42) q_full = torch.randn(req.B, req.S, req.HQ, req.D, device=device, dtype=torch.bfloat16) @@ -56,11 +63,15 @@ def shard(x: torch.Tensor) -> torch.Tensor: v_ref = v_full.clone().requires_grad_(True) out_ref = flash_attn_func(q_ref, k_ref, v_ref, softmax_scale, causal=True) out_ref.sum().backward() - ref = Result(shard(out_ref), shard(q_ref.grad), shard(k_ref.grad), shard(v_ref.grad)) - - q_imp = shard(q_full).clone().requires_grad_(True) - k_imp = shard(k_full).clone().requires_grad_(True) - v_imp = shard(v_full).clone().requires_grad_(True) + out_ref = extract_zigzag(out_ref, cp_rank, cp_size) + dq_ref = extract_zigzag(q_ref.grad, cp_rank, cp_size) + dk_ref = extract_zigzag(k_ref.grad, cp_rank, cp_size) + dv_ref = extract_zigzag(v_ref.grad, cp_rank, cp_size) + ref = Result(out_ref, dq_ref, dk_ref, dv_ref) + + q_imp = extract_zigzag(q_full, cp_rank, cp_size).clone().requires_grad_(True) + k_imp = extract_zigzag(k_full, cp_rank, cp_size).clone().requires_grad_(True) + v_imp = extract_zigzag(v_full, cp_rank, cp_size).clone().requires_grad_(True) out_imp = ring_attention_func(q_imp, k_imp, v_imp, softmax_scale, cp_group) out_imp.sum().backward() imp = Result(out_imp, q_imp.grad, k_imp.grad, v_imp.grad)