Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions benchmarks/operators/bench_ring_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Benchmark forward and backward latency of ring attention."""

import json
import re
import sys
from pathlib import Path
from types import SimpleNamespace

import torch

from pithtrain.modules.distributed import DistributedCfg, DistributedCtx, distributed_context
from pithtrain.operators.ring_attention import ring_attention_func


def parse_scenario(scenario: str) -> tuple[dict, int, int]:
    """Parse a benchmark scenario name and load the matching model config.

    A scenario has the form ``<model>-cp<N>-s<N>k``, for example
    ``qwen3-30b-a3b-cp4-s32k``.

    Args:
        scenario: Scenario name to parse.

    Returns:
        A ``(config, cp_size, seq_len)`` tuple: the parsed JSON from
        ``examples/pretrain_language_model/<model>/config.json``, the integer
        following ``cp``, and the integer following ``s`` multiplied by 1024.

    Raises:
        ValueError: If the scenario does not have the expected form, or if the
            cp size or the sequence length is zero.
        FileNotFoundError: If the model's config.json does not exist.
    """
    # fullmatch rejects a trailing newline, which "^...$" would let through.
    m = re.fullmatch(r"(.+)-cp(\d+)-s(\d+)k", scenario)
    if not m:
        raise ValueError(f"invalid scenario '{scenario}', expected <model>-cp<N>-s<N>k")
    model = m.group(1)
    cp_size, seq_k = int(m.group(2)), int(m.group(3))
    # Reject zeros up front: a zero cp size would otherwise only surface later
    # as a ZeroDivisionError when the sequence is split across ranks.
    if cp_size < 1 or seq_k < 1:
        raise ValueError(
            f"invalid scenario '{scenario}', cp size and sequence length must be positive"
        )
    # JSON is UTF-8 by specification; do not depend on the locale's default encoding.
    with open(Path(f"examples/pretrain_language_model/{model}/config.json"), encoding="utf-8") as f:
        config = json.load(f)
    return config, cp_size, seq_k * 1024


def run(ctx: DistributedCtx, scenario: str, config: dict, cp_size: int, S: int) -> None:
    """Time ring attention forward and backward on this rank, then profile one iteration.

    Args:
        ctx: Distributed context; its device mesh supplies the "cp" process group.
        scenario: Label printed in front of the timings.
        config: Model config providing "num_attention_heads",
            "num_key_value_heads" and "head_dim".
        cp_size: Number of shards the sequence is split into.
        S: Full sequence length; each rank benchmarks S // cp_size tokens.

    Raises:
        ValueError: If cp_size is not positive or does not divide S evenly.
    """
    # Validate before touching the GPU: floor division would silently drop the
    # remainder, so the printed scenario would not describe the measured shape.
    if cp_size < 1 or S % cp_size != 0:
        raise ValueError(f"sequence length {S} is not divisible by cp_size {cp_size}")

    B = 1
    WARMUP, NITERS = 25, 100
    HQ, HK = config["num_attention_heads"], config["num_key_value_heads"]
    D = config["head_dim"]

    cp_group = ctx.device_mesh.get_group("cp")
    device = torch.cuda.current_device()
    softmax_scale = D**-0.5
    S_local = S // cp_size

    # Fixed seed so every run benchmarks the same inputs.
    torch.manual_seed(42)
    kwargs = dict(device=device, dtype=torch.bfloat16)
    q = torch.randn(B, S_local, HQ, D, requires_grad=True, **kwargs)
    k = torch.randn(B, S_local, HK, D, requires_grad=True, **kwargs)
    v = torch.randn(B, S_local, HK, D, requires_grad=True, **kwargs)
    grad_out = torch.randn(B, S_local, HQ, D, **kwargs)

    def run_once() -> None:
        # Drop the previous gradients so backward does not accumulate into them.
        q.grad, k.grad, v.grad = None, None, None
        out = ring_attention_func(q, k, v, softmax_scale, cp_group)
        out.backward(grad_out)

    for _ in range(WARMUP):
        run_once()
    torch.cuda.synchronize()

    # Timed forward/backward, separated by CUDA events.
    fwd_total_ms = 0.0
    bwd_total_ms = 0.0
    for _ in range(NITERS):
        q.grad, k.grad, v.grad = None, None, None
        fwd_start = torch.cuda.Event(enable_timing=True)
        fwd_end = torch.cuda.Event(enable_timing=True)
        bwd_end = torch.cuda.Event(enable_timing=True)
        fwd_start.record()
        out = ring_attention_func(q, k, v, softmax_scale, cp_group)
        fwd_end.record()
        out.backward(grad_out)
        bwd_end.record()
        # Events must have completed before their elapsed time can be read.
        torch.cuda.synchronize()
        fwd_total_ms += fwd_start.elapsed_time(fwd_end)
        bwd_total_ms += fwd_end.elapsed_time(bwd_end)

    fwd_avg = fwd_total_ms / NITERS
    bwd_avg = bwd_total_ms / NITERS

    # Only rank 0 reports; the numbers are this rank's local averages.
    if ctx.rank == 0:
        print(f"{scenario} | fwd: {fwd_avg:7.3f} ms , bwd: {bwd_avg:7.3f} ms", flush=True)
    torch.distributed.barrier()

    # Nsys profile capture with one iteration.
    torch.cuda.synchronize()
    torch.cuda.profiler.start()
    run_once()
    torch.cuda.synchronize()
    torch.cuda.profiler.stop()


if __name__ == "__main__":
    # Exit with a usage message instead of an IndexError traceback when the
    # scenario argument is missing.
    if len(sys.argv) < 2:
        raise SystemExit(f"usage: {sys.argv[0]} <model>-cp<N>-s<N>k")
    scenario = sys.argv[1]
    config, cp_size, S = parse_scenario(scenario)

    cfg = DistributedCfg()
    cfg.context_parallel_size = cp_size
    # distributed_context is given objects that expose the distributed config
    # and context under a `.distributed` attribute.
    parent_cfg = SimpleNamespace(distributed=cfg)
    parent_ctx = SimpleNamespace(distributed=DistributedCtx())
    with distributed_context(parent_cfg, parent_ctx) as ctx:
        run(ctx, scenario, config, cp_size, S)
44 changes: 44 additions & 0 deletions benchmarks/operators/bench_ring_attention.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#!/bin/bash
# Benchmark ring attention. Always captures an nsys profile of the final iteration.
#
# Usage:
#   bash benchmarks/operators/bench_ring_attention.sh qwen3-30b-a3b-cp4-s32k

set -euo pipefail
export OMP_NUM_THREADS=8

SCRIPT=benchmarks/operators/bench_ring_attention.py
OUTDIR=workspace/benchmarks/operators

if [ $# -lt 1 ]; then
    echo "Usage: $0 <scenario> (e.g. qwen3-30b-a3b-cp4-s32k)" >&2
    exit 1
fi
SCENARIO=$1

# Extract cp_size from the scenario for torchrun's --nproc-per-node.
if [[ ! $SCENARIO =~ -cp([0-9]+)- ]]; then
    echo "Scenario '$SCENARIO' missing -cp<N>- segment" >&2
    exit 1
fi
NPROC=${BASH_REMATCH[1]}

# Create the output directory only once the arguments are known to be valid.
mkdir -p "$OUTDIR"

NSYS_ARGS=()
NSYS_ARGS+=(profile)
NSYS_ARGS+=(--stats=false)
NSYS_ARGS+=(--trace=cuda,osrt,nvtx)
NSYS_ARGS+=(--force-overwrite=true)
NSYS_ARGS+=("--output=$OUTDIR/ring_attention.$SCENARIO")
NSYS_ARGS+=(--cuda-graph-trace=node)
# Capture only the range bracketed by the CUDA profiler start/stop calls in the script.
NSYS_ARGS+=(--capture-range=cudaProfilerApi)
NSYS_ARGS+=(--capture-range-end=stop-shutdown)
NSYS_ARGS+=(--delay=0)

TORCHRUN_ARGS=()
TORCHRUN_ARGS+=(--nnodes=1)
TORCHRUN_ARGS+=("--nproc-per-node=$NPROC")
TORCHRUN_ARGS+=(--rdzv-backend=c10d)
TORCHRUN_ARGS+=(--rdzv-endpoint=localhost:15213)

# Quote every expansion so each argument reaches nsys/torchrun as exactly one
# word, with no word splitting or globbing.
nsys "${NSYS_ARGS[@]}" torchrun "${TORCHRUN_ARGS[@]}" "$SCRIPT" "$SCENARIO"
Comment thread
haok1402 marked this conversation as resolved.
2 changes: 1 addition & 1 deletion pithtrain/models/deepseek_v2_lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from pithtrain.modules.load_balance import MoELoadBalanceLossInjector, MoELoadBalanceLossTracker
from pithtrain.operators.ep_dispatch import moe_ep_prepare_dispatch
from pithtrain.operators.flash_attn_v4 import mla_flash_attn_func
from pithtrain.operators.ring_attention.standard import ring_attention_func
from pithtrain.operators.ring_attention import ring_attention_func
from pithtrain.operators.silu_mul import silu_mul
from pithtrain.operators.token_scatter import (
padded_index_gather,
Expand Down
2 changes: 1 addition & 1 deletion pithtrain/models/qwen3_30b_a3b.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from pithtrain.modules.load_balance import MoELoadBalanceLossInjector, MoELoadBalanceLossTracker
from pithtrain.operators.ep_dispatch import moe_ep_prepare_dispatch
from pithtrain.operators.flash_attn_v4 import flash_attn_func
from pithtrain.operators.ring_attention.standard import ring_attention_func
from pithtrain.operators.ring_attention import ring_attention_func
from pithtrain.operators.silu_mul import silu_mul
from pithtrain.operators.token_scatter import (
padded_index_gather,
Expand Down
Empty file.
Empty file removed tests/operators/mla/__init__.py
Empty file.
89 changes: 89 additions & 0 deletions tests/operators/test_ring_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""Test the correctness of ring attention under context parallelism."""

from dataclasses import dataclass, fields

import pytest
import torch

from pithtrain.modules.distributed import DistributedCfg, DistributedCtx
from pithtrain.operators.flash_attn_v4 import flash_attn_func
from pithtrain.operators.ring_attention import ring_attention_func
from tests.utilities import cosine_error, launch


@dataclass
class Request:
    """Problem size and error tolerance for one ring attention test case."""

    # Batch size and full (unsharded) sequence length.
    B: int
    S: int
    # Number of query heads and number of key/value heads.
    HQ: int
    HK: int
    # Per-head dimension; also used to derive the softmax scale.
    D: int
    # verify() fails a tensor once its cosine error reaches this value.
    atol: float = 1e-5


@dataclass
class Result:
    """Forward output and input gradients captured from one attention run."""

    # Attention output.
    out: torch.Tensor
    # Gradients with respect to the query, key and value inputs.
    dq: torch.Tensor
    dk: torch.Tensor
    dv: torch.Tensor


def record(ctx: DistributedCtx, req: Request) -> tuple[Result, Result]:
    """
    Record the forward output and the input gradients dQ, dK, dV for both the
    baseline and the implementation. The baseline is causal flash_attn_func run
    on the full sequence with no CP communication, sliced down to this rank's
    shard afterwards; the implementation is ring_attention_func run on this
    rank's sequence shard with K/V exchanged around the CP ring.

    Returns the pair (baseline, implementation).

    Raises ValueError if req.S is not divisible by the CP group size, since the
    shards would then not cover the full sequence.
    """
    cp_group = ctx.device_mesh.get_group("cp")
    cp_rank, cp_size = cp_group.rank(), cp_group.size()
    # Floor division would silently drop the trailing tokens from every shard.
    if req.S % cp_size != 0:
        raise ValueError(f"sequence length {req.S} is not divisible by cp_size {cp_size}")
    device = torch.cuda.current_device()
    softmax_scale = req.D**-0.5
    S_local = req.S // cp_size

    def shard(x: torch.Tensor) -> torch.Tensor:
        # This rank owns one contiguous block of S_local positions along dim 1.
        return x[:, cp_rank * S_local : (cp_rank + 1) * S_local]

    # Same seed on every rank, so all ranks build identical full-sequence inputs.
    torch.manual_seed(42)
    q_full = torch.randn(req.B, req.S, req.HQ, req.D, device=device, dtype=torch.bfloat16)
    k_full = torch.randn(req.B, req.S, req.HK, req.D, device=device, dtype=torch.bfloat16)
    v_full = torch.randn(req.B, req.S, req.HK, req.D, device=device, dtype=torch.bfloat16)

    # Baseline: full-sequence attention, then keep only this rank's slice.
    q_ref = q_full.clone().requires_grad_(True)
    k_ref = k_full.clone().requires_grad_(True)
    v_ref = v_full.clone().requires_grad_(True)
    out_ref = flash_attn_func(q_ref, k_ref, v_ref, softmax_scale, causal=True)
    out_ref.sum().backward()
    ref = Result(shard(out_ref), shard(q_ref.grad), shard(k_ref.grad), shard(v_ref.grad))

    # Implementation: each rank feeds only its own shard to ring attention.
    q_imp = shard(q_full).clone().requires_grad_(True)
    k_imp = shard(k_full).clone().requires_grad_(True)
    v_imp = shard(v_full).clone().requires_grad_(True)
    out_imp = ring_attention_func(q_imp, k_imp, v_imp, softmax_scale, cp_group)
    out_imp.sum().backward()
    imp = Result(out_imp, q_imp.grad, k_imp.grad, v_imp.grad)

    return ref, imp


def verify(ctx: DistributedCtx, req: Request) -> None:
    """Compare every recorded tensor of the implementation against the baseline.

    Raises an AssertionError naming the first Result field whose cosine error
    is greater than or equal to req.atol.
    """
    ref, imp = record(ctx, req)
    for f in fields(ref):
        expected = getattr(ref, f.name)
        actual = getattr(imp, f.name)
        error = cosine_error(expected, actual)
        if error >= req.atol:
            raise AssertionError(f"{f.name} diverged: {error=:.2e} >= {req.atol=}")


# Parametrized cases as (cp_size, Request); each id encodes the CP size, the
# head layout (MHA when HQ == HK, GQA when HQ > HK) and the sequence length.
REQUESTS = [
    pytest.param(2, Request(B=1, S=2048, HQ=4, HK=4, D=64), id="CP2-MHA-S2048"),
    pytest.param(2, Request(B=2, S=2048, HQ=8, HK=2, D=64), id="CP2-GQA-S2048"),
    pytest.param(4, Request(B=1, S=4096, HQ=12, HK=4, D=128), id="CP4-GQA-S4096"),
]


@pytest.mark.parametrize("cp_size,req", REQUESTS)
def test_ring_attention_vs_dense(cp_size: int, req: Request) -> None:
    """Hand verify and req to the launch helper with context parallel size cp_size."""
    dist_cfg = DistributedCfg()
    dist_cfg.context_parallel_size = cp_size
    launch(dist_cfg, verify, req)
42 changes: 0 additions & 42 deletions tests/operators/utilities.py

This file was deleted.

Loading
Loading