diff --git a/benchmarks/operators/bench_ring_attention.py b/benchmarks/operators/bench_ring_attention.py new file mode 100644 index 0000000..e242596 --- /dev/null +++ b/benchmarks/operators/bench_ring_attention.py @@ -0,0 +1,93 @@ +"""Benchmark forward and backward latency of ring attention.""" + +import json +import re +import sys +from pathlib import Path +from types import SimpleNamespace + +import torch + +from pithtrain.modules.distributed import DistributedCfg, DistributedCtx, distributed_context +from pithtrain.operators.ring_attention import ring_attention_func + + +def parse_scenario(scenario: str) -> tuple[dict, int, int]: + m = re.match(r"^(.+)-cp(\d+)-s(\d+)k$", scenario) + if not m: + raise ValueError(f"invalid scenario '{scenario}', expected -cp-sk") + model = m.group(1) + with open(Path(f"examples/pretrain_language_model/{model}/config.json")) as f: + config = json.load(f) + return config, int(m.group(2)), int(m.group(3)) * 1024 + + +def run(ctx: DistributedCtx, scenario: str, config: dict, cp_size: int, S: int) -> None: + B = 1 + WARMUP, NITERS = 25, 100 + HQ, HK = config["num_attention_heads"], config["num_key_value_heads"] + D = config["head_dim"] + + cp_group = ctx.device_mesh.get_group("cp") + device = torch.cuda.current_device() + softmax_scale = D**-0.5 + S_local = S // cp_size + + torch.manual_seed(42) + kwargs = dict(device=device, dtype=torch.bfloat16) + q = torch.randn(B, S_local, HQ, D, requires_grad=True, **kwargs) + k = torch.randn(B, S_local, HK, D, requires_grad=True, **kwargs) + v = torch.randn(B, S_local, HK, D, requires_grad=True, **kwargs) + grad_out = torch.randn(B, S_local, HQ, D, **kwargs) + + def run_once() -> None: + q.grad, k.grad, v.grad = None, None, None + out = ring_attention_func(q, k, v, softmax_scale, cp_group) + out.backward(grad_out) + + for _ in range(WARMUP): + run_once() + torch.cuda.synchronize() + + # Timed forward/backward, separated by CUDA events. + fwd_total_ms = 0.0 + bwd_total_ms = 0.0 + for _ in range(NITERS): + q.grad, k.grad, v.grad = None, None, None + fwd_start = torch.cuda.Event(enable_timing=True) + fwd_end = torch.cuda.Event(enable_timing=True) + bwd_end = torch.cuda.Event(enable_timing=True) + fwd_start.record() + out = ring_attention_func(q, k, v, softmax_scale, cp_group) + fwd_end.record() + out.backward(grad_out) + bwd_end.record() + torch.cuda.synchronize() + fwd_total_ms += fwd_start.elapsed_time(fwd_end) + bwd_total_ms += fwd_end.elapsed_time(bwd_end) + + fwd_avg = fwd_total_ms / NITERS + bwd_avg = bwd_total_ms / NITERS + + if ctx.rank == 0: + print(f"{scenario} | fwd: {fwd_avg:7.3f} ms , bwd: {bwd_avg:7.3f} ms", flush=True) + torch.distributed.barrier() + + # Nsys profile capture with one iteration. + torch.cuda.synchronize() + torch.cuda.profiler.start() + run_once() + torch.cuda.synchronize() + torch.cuda.profiler.stop() + + +if __name__ == "__main__": + scenario = sys.argv[1] + config, cp_size, S = parse_scenario(scenario) + + cfg = DistributedCfg() + cfg.context_parallel_size = cp_size + parent_cfg = SimpleNamespace(distributed=cfg) + parent_ctx = SimpleNamespace(distributed=DistributedCtx()) + with distributed_context(parent_cfg, parent_ctx) as ctx: + run(ctx, scenario, config, cp_size, S) diff --git a/benchmarks/operators/bench_ring_attention.sh b/benchmarks/operators/bench_ring_attention.sh new file mode 100755 index 0000000..aca1049 --- /dev/null +++ b/benchmarks/operators/bench_ring_attention.sh @@ -0,0 +1,44 @@ +#!/bin/bash +# Benchmark ring attention. Always captures an nsys profile of the final iteration. +# +# Usage: +# bash benchmarks/operators/bench_ring_attention.sh qwen3-30b-a3b-cp4-s32k + +set -euo pipefail +export OMP_NUM_THREADS=8 + +SCRIPT=benchmarks/operators/bench_ring_attention.py +OUTDIR=workspace/benchmarks/operators +mkdir -p $OUTDIR + +if [ $# -lt 1 ]; then + echo "Usage: $0 (e.g. qwen3-30b-a3b-cp4-s32k)" >&2 + exit 1 +fi +SCENARIO=$1 + +# Extract cp_size from the scenario for torchrun's --nproc-per-node. +if [[ ! $SCENARIO =~ -cp([0-9]+)- ]]; then + echo "Scenario '$SCENARIO' missing -cp- segment" >&2 + exit 1 +fi +NPROC=${BASH_REMATCH[1]} + +NSYS_ARGS=() +NSYS_ARGS+=(profile) +NSYS_ARGS+=(--stats=false) +NSYS_ARGS+=(--trace=cuda,osrt,nvtx) +NSYS_ARGS+=(--force-overwrite=true) +NSYS_ARGS+=(--output=$OUTDIR/ring_attention.$SCENARIO) +NSYS_ARGS+=(--cuda-graph-trace=node) +NSYS_ARGS+=(--capture-range=cudaProfilerApi) +NSYS_ARGS+=(--capture-range-end=stop-shutdown) +NSYS_ARGS+=(--delay=0) + +TORCHRUN_ARGS=() +TORCHRUN_ARGS+=(--nnodes=1) +TORCHRUN_ARGS+=(--nproc-per-node=$NPROC) +TORCHRUN_ARGS+=(--rdzv-backend=c10d) +TORCHRUN_ARGS+=(--rdzv-endpoint=localhost:15213) + +nsys ${NSYS_ARGS[@]} torchrun ${TORCHRUN_ARGS[@]} $SCRIPT $SCENARIO diff --git a/pithtrain/models/deepseek_v2_lite.py b/pithtrain/models/deepseek_v2_lite.py index 2467ee5..541fbd2 100644 --- a/pithtrain/models/deepseek_v2_lite.py +++ b/pithtrain/models/deepseek_v2_lite.py @@ -19,7 +19,7 @@ from pithtrain.modules.load_balance import MoELoadBalanceLossInjector, MoELoadBalanceLossTracker from pithtrain.operators.ep_dispatch import moe_ep_prepare_dispatch from pithtrain.operators.flash_attn_v4 import mla_flash_attn_func -from pithtrain.operators.ring_attention.standard import ring_attention_func +from pithtrain.operators.ring_attention import ring_attention_func from pithtrain.operators.silu_mul import silu_mul from pithtrain.operators.token_scatter import ( padded_index_gather, diff --git a/pithtrain/models/qwen3_30b_a3b.py b/pithtrain/models/qwen3_30b_a3b.py index 4309f6f..ba51059 100644 --- a/pithtrain/models/qwen3_30b_a3b.py +++ b/pithtrain/models/qwen3_30b_a3b.py @@ -17,7 +17,7 @@ from pithtrain.modules.load_balance import MoELoadBalanceLossInjector, MoELoadBalanceLossTracker from pithtrain.operators.ep_dispatch import moe_ep_prepare_dispatch from pithtrain.operators.flash_attn_v4 import flash_attn_func -from pithtrain.operators.ring_attention.standard import ring_attention_func +from pithtrain.operators.ring_attention import ring_attention_func from pithtrain.operators.silu_mul import silu_mul from pithtrain.operators.token_scatter import ( padded_index_gather, diff --git a/pithtrain/operators/ring_attention/standard.py b/pithtrain/operators/ring_attention.py similarity index 100% rename from pithtrain/operators/ring_attention/standard.py rename to pithtrain/operators/ring_attention.py diff --git a/pithtrain/operators/ring_attention/__init__.py b/pithtrain/operators/ring_attention/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/operators/mla/__init__.py b/tests/operators/mla/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/operators/test_ring_attention.py b/tests/operators/test_ring_attention.py new file mode 100644 index 0000000..0fe528b --- /dev/null +++ b/tests/operators/test_ring_attention.py @@ -0,0 +1,89 @@ +"""Test the correctness of ring attention under context parallelism.""" + +from dataclasses import dataclass, fields + +import pytest +import torch + +from pithtrain.modules.distributed import DistributedCfg, DistributedCtx +from pithtrain.operators.flash_attn_v4 import flash_attn_func +from pithtrain.operators.ring_attention import ring_attention_func +from tests.utilities import cosine_error, launch + + +@dataclass +class Request: + B: int + S: int + HQ: int + HK: int + D: int + atol: float = 1e-5 + + +@dataclass +class Result: + out: torch.Tensor + dq: torch.Tensor + dk: torch.Tensor + dv: torch.Tensor + + +def record(ctx: DistributedCtx, req: Request) -> tuple[Result, Result]: + """ + Record the forward output and the input gradients dQ, dK, dV for both the + baseline and the implementation. The baseline is flash_attn_func run on the + full sequence with no CP communication, effectively the full sequence at + once; the implementation is ring_attention_func run on this rank's sequence + shard with K/V exchanged around the CP ring. + """ + cp_group = ctx.device_mesh.get_group("cp") + cp_rank, cp_size = cp_group.rank(), cp_group.size() + device = torch.cuda.current_device() + softmax_scale = req.D**-0.5 + S_local = req.S // cp_size + + def shard(x: torch.Tensor) -> torch.Tensor: + return x[:, cp_rank * S_local : (cp_rank + 1) * S_local] + + torch.manual_seed(42) + q_full = torch.randn(req.B, req.S, req.HQ, req.D, device=device, dtype=torch.bfloat16) + k_full = torch.randn(req.B, req.S, req.HK, req.D, device=device, dtype=torch.bfloat16) + v_full = torch.randn(req.B, req.S, req.HK, req.D, device=device, dtype=torch.bfloat16) + + q_ref = q_full.clone().requires_grad_(True) + k_ref = k_full.clone().requires_grad_(True) + v_ref = v_full.clone().requires_grad_(True) + out_ref = flash_attn_func(q_ref, k_ref, v_ref, softmax_scale, causal=True) + out_ref.sum().backward() + ref = Result(shard(out_ref), shard(q_ref.grad), shard(k_ref.grad), shard(v_ref.grad)) + + q_imp = shard(q_full).clone().requires_grad_(True) + k_imp = shard(k_full).clone().requires_grad_(True) + v_imp = shard(v_full).clone().requires_grad_(True) + out_imp = ring_attention_func(q_imp, k_imp, v_imp, softmax_scale, cp_group) + out_imp.sum().backward() + imp = Result(out_imp, q_imp.grad, k_imp.grad, v_imp.grad) + + return ref, imp + + +def verify(ctx: DistributedCtx, req: Request) -> None: + ref, imp = record(ctx, req) + for f in fields(ref): + error = cosine_error(getattr(ref, f.name), getattr(imp, f.name)) + if error >= req.atol: + raise AssertionError(f"{f.name} diverged: {error=:.2e} >= {req.atol=}") + + +REQUESTS = [] +REQUESTS.append(pytest.param(2, Request(B=1, S=2048, HQ=4, HK=4, D=64), id="CP2-MHA-S2048")) +REQUESTS.append(pytest.param(2, Request(B=2, S=2048, HQ=8, HK=2, D=64), id="CP2-GQA-S2048")) +REQUESTS.append(pytest.param(4, Request(B=1, S=4096, HQ=12, HK=4, D=128), id="CP4-GQA-S4096")) + + +@pytest.mark.parametrize("cp_size,req", REQUESTS) +def test_ring_attention_vs_dense(cp_size: int, req: Request) -> None: + cfg = DistributedCfg() + cfg.context_parallel_size = cp_size + launch(cfg, verify, req) diff --git a/tests/operators/utilities.py b/tests/operators/utilities.py deleted file mode 100644 index 07b6cbe..0000000 --- a/tests/operators/utilities.py +++ /dev/null @@ -1,42 +0,0 @@ -import torch - - -# fmt: off -def assert_close( - actual: torch.Tensor, expected: torch.Tensor, - rtol: float = 1.6e-2, atol: float = 1e-5, otol: float = 0.0, -) -> None: - """ - Assert two tensors are close, with detailed error message on failure. - - Parameters: - ---------- - actual: torch.Tensor - The actual tensor. - expected: torch.Tensor - The expected tensor. - rtol: float - Relative tolerance. Default value is 1.6e-2. - atol: float - Absolute tolerance. Default value is 1e-5. - otol: float - Outlier tolerance: maximum allowed percentage of mismatched elements. - It is between 0.0 and 1.0 with default value 0.0. - - Raises: - ------ - AssertionError - If tensors are not close within specified tolerances. - """ - __tracebackhide__ = True - close = torch.isclose(actual, expected, rtol=rtol, atol=atol) - ofrac = (~close).float().mean().item() - - if ofrac > otol: - lines = [] - lines.append("Tensor-likes are not close!") - lines.append(f" Shape: {tuple(actual.shape)}") - lines.append(f" Mismatch rate: {ofrac:.2%} (allowed: {otol:.2%})") - lines.append(f" Greatest absolute difference: {(actual - expected).abs().max():.6g}") - raise AssertionError("\n".join(lines)) -# fmt: on diff --git a/tests/test_ring_attention.py b/tests/test_ring_attention.py deleted file mode 100644 index 9ec563a..0000000 --- a/tests/test_ring_attention.py +++ /dev/null @@ -1,120 +0,0 @@ -""" -Correctness test for ring attention (context parallelism). - -Compares the output and gradients of ring_attention_func (split across CP ranks) -against a single flash_attn_func call on the full un-split sequence. - -Launch with: - torchrun --nproc-per-node=2 tests/test_ring_attention.py -""" - -import os -import sys - -import torch -import torch.distributed as dist - -from pithtrain.operators.flash_attn_v4 import flash_attn_func -from pithtrain.operators.ring_attention.standard import ring_attention_func - - -def setup(): - dist.init_process_group("nccl") - rank = int(os.environ["LOCAL_RANK"]) - torch.cuda.set_device(rank) - return rank, dist.get_world_size() - - -def _cosine_error(a, b): - """1 - cosine similarity: 0 = identical, 2 = opposite.""" - a, b = a.double().flatten(), b.double().flatten() - return 1.0 - torch.dot(a, b) / (a.norm() * b.norm() + 1e-12) - - -def _rel_error(a, b): - """Max element-wise relative error.""" - a, b = a.float(), b.float() - return ((a - b).abs() / (b.abs().clamp(min=1e-6))).max().item() - - -def test_forward_backward(B, S, H, D, num_kv_heads, cp_group, dtype=torch.bfloat16): - """Compare ring attention on split sequence vs standard attention on full sequence.""" - rank = cp_group.rank() - cp_size = cp_group.size() - device = torch.cuda.current_device() - scale = D**-0.5 - S_local = S // cp_size - - torch.manual_seed(42) - q_full = torch.randn(B, S, H, D, device=device, dtype=dtype) - k_full = torch.randn(B, S, num_kv_heads, D, device=device, dtype=dtype) - v_full = torch.randn(B, S, num_kv_heads, D, device=device, dtype=dtype) - - # --- Reference: standard causal flash attention on full sequence --- - q_ref = q_full.clone().requires_grad_(True) - k_ref = k_full.clone().requires_grad_(True) - v_ref = v_full.clone().requires_grad_(True) - out_ref = flash_attn_func(q_ref, k_ref, v_ref, softmax_scale=scale, causal=True) - loss_ref = out_ref.sum() - loss_ref.backward() - - # --- Ring attention: each rank gets its chunk --- - q_chunk = q_full[:, rank * S_local : (rank + 1) * S_local].clone().requires_grad_(True) - k_chunk = k_full[:, rank * S_local : (rank + 1) * S_local].clone().requires_grad_(True) - v_chunk = v_full[:, rank * S_local : (rank + 1) * S_local].clone().requires_grad_(True) - out_ring = ring_attention_func( - q_chunk, k_chunk, v_chunk, softmax_scale=scale, cp_group=cp_group - ) - loss_ring = out_ring.sum() - loss_ring.backward() - - # --- Compare forward output --- - out_ref_chunk = out_ref[:, rank * S_local : (rank + 1) * S_local] - fwd_cos = _cosine_error(out_ring, out_ref_chunk).item() - - # --- Compare gradients (cosine similarity) --- - dq_ref_chunk = q_ref.grad[:, rank * S_local : (rank + 1) * S_local] - dk_ref_chunk = k_ref.grad[:, rank * S_local : (rank + 1) * S_local] - dv_ref_chunk = v_ref.grad[:, rank * S_local : (rank + 1) * S_local] - - dq_cos = _cosine_error(q_chunk.grad, dq_ref_chunk).item() - dk_cos = _cosine_error(k_chunk.grad, dk_ref_chunk).item() - dv_cos = _cosine_error(v_chunk.grad, dv_ref_chunk).item() - - return fwd_cos, dq_cos, dk_cos, dv_cos - - -def main(): - rank, world_size = setup() - assert world_size >= 2, "Need at least 2 GPUs" - - cp_group = dist.new_group(list(range(world_size))) - atol = 1e-5 - - configs = [ - {"B": 1, "S": 128, "H": 4, "D": 64, "num_kv_heads": 4, "label": "MHA S=128"}, - {"B": 2, "S": 256, "H": 8, "D": 64, "num_kv_heads": 2, "label": "GQA S=256"}, - {"B": 1, "S": 512, "H": 12, "D": 128, "num_kv_heads": 4, "label": "GQA S=512 D=128"}, - ] - - all_passed = True - for cfg in configs: - label = cfg.pop("label") - fwd, dq, dk, dv = test_forward_backward(**cfg, cp_group=cp_group) - worst = max(fwd, dq, dk, dv) - passed = worst < atol - all_passed &= passed - if rank == 0: - status = "PASS" if passed else "FAIL" - print( - f"[{status}] {label}: fwd_cos={fwd:.2e} dQ_cos={dq:.2e} dK_cos={dk:.2e} dV_cos={dv:.2e}" - ) - - dist.destroy_process_group() - if rank == 0: - print("\nAll tests passed." if all_passed else "\nSome tests FAILED.") - sys.exit(0 if all_passed else 1) - - -if __name__ == "__main__": - main() diff --git a/tests/utilities.py b/tests/utilities.py new file mode 100644 index 0000000..d09cced --- /dev/null +++ b/tests/utilities.py @@ -0,0 +1,70 @@ +"""Spawn workers and run them under PithTrain's distributed context.""" + +import os +from collections.abc import Callable +from types import SimpleNamespace + +import pytest +import torch +from torch.multiprocessing.spawn import spawn + +from pithtrain.modules.distributed import DistributedCfg, DistributedCtx, distributed_context + +# Snapshot launcher-provided values at module load. Within a single pytest +# session, multiple launch calls would otherwise pollute each other's env via +# setdefault no-ops when mesh_extent differs between parametrizations. +LAUNCHER_WORLD_SIZE = os.environ.get("WORLD_SIZE") +LAUNCHER_LOCAL_WORLD_SIZE = os.environ.get("LOCAL_WORLD_SIZE") + + +def cosine_error(a: torch.Tensor, b: torch.Tensor) -> float: + """ + Return 1 - cosine similarity: 0 if a and b point the same direction, 2 if + opposite. Scale-invariant, so magnitude shifts in low-precision do not inflate it. + """ + a, b = a.double().flatten(), b.double().flatten() + return float(1.0 - torch.dot(a, b) / (a.norm() * b.norm() + 1e-12)) + + +def entrypoint(i: int, cfg: DistributedCfg, worker: Callable, *args) -> None: + node_rank = int(os.environ["NODE_RANK"]) + world_local_size = int(os.environ["LOCAL_WORLD_SIZE"]) + os.environ["RANK"] = str(node_rank * world_local_size + i) + os.environ["LOCAL_RANK"] = str(i) + + parent_cfg = SimpleNamespace(distributed=cfg) + parent_ctx = SimpleNamespace(distributed=DistributedCtx()) + with distributed_context(parent_cfg, parent_ctx) as ctx: + worker(ctx, *args) + + +def launch(cfg: DistributedCfg, worker: Callable, *args) -> None: + """ + Spawn workers and call worker(ctx, *args) inside each. Skip the test if + the distributed runtime cannot provide pp * cp * ep ranks. + """ + mesh_extent = 1 + mesh_extent *= cfg.pipeline_parallel_size + mesh_extent *= cfg.context_parallel_size + mesh_extent *= cfg.expert_parallel_size + + os.environ["WORLD_SIZE"] = LAUNCHER_WORLD_SIZE or str(mesh_extent) + os.environ["LOCAL_WORLD_SIZE"] = LAUNCHER_LOCAL_WORLD_SIZE or os.environ["WORLD_SIZE"] + os.environ.setdefault("NODE_RANK", "0") + os.environ.setdefault("MASTER_ADDR", "localhost") + os.environ.setdefault("MASTER_PORT", "15213") + os.environ.setdefault("TORCHELASTIC_RUN_ID", "pytest") + + world_size = int(os.environ["WORLD_SIZE"]) + local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + + if world_size < mesh_extent: + pytest.skip(f"require {mesh_extent} ranks, got {world_size}") + if torch.cuda.device_count() < local_world_size: + pytest.skip(f"require {local_world_size} GPUs, got {torch.cuda.device_count()}") + if world_size % mesh_extent != 0: + raise ValueError(f"{world_size=} not divisible by {mesh_extent=}") + if world_size % local_world_size != 0: + raise ValueError(f"{world_size=} not divisible by {local_world_size=}") + + spawn(entrypoint, args=(cfg, worker, *args), nprocs=local_world_size)