-
Notifications
You must be signed in to change notification settings - Fork 6
Refactor testing framework, add ring attention benchmark #38
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
deedd99
remove legacy tests for mla operators
haok1402 ed79ff5
import ring attention directly ; without that "standard" remark
haok1402 11c47e4
revise how we do the testing ; add the distributed setup for pytest i…
haok1402 d66c48a
add the benchmarking scripts for the ring attention with profile capt…
haok1402 5a72727
cleanup the benchmark scripts for ring attention to capture the nsys …
haok1402 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,93 @@ | ||
| """Benchmark forward and backward latency of ring attention.""" | ||
|
|
||
| import json | ||
| import re | ||
| import sys | ||
| from pathlib import Path | ||
| from types import SimpleNamespace | ||
|
|
||
| import torch | ||
|
|
||
| from pithtrain.modules.distributed import DistributedCfg, DistributedCtx, distributed_context | ||
| from pithtrain.operators.ring_attention import ring_attention_func | ||
|
|
||
|
|
||
| def parse_scenario(scenario: str) -> tuple[dict, int, int]: | ||
| m = re.match(r"^(.+)-cp(\d+)-s(\d+)k$", scenario) | ||
| if not m: | ||
| raise ValueError(f"invalid scenario '{scenario}', expected <model>-cp<N>-s<N>k") | ||
| model = m.group(1) | ||
| with open(Path(f"examples/pretrain_language_model/{model}/config.json")) as f: | ||
| config = json.load(f) | ||
| return config, int(m.group(2)), int(m.group(3)) * 1024 | ||
|
|
||
|
|
||
| def run(ctx: DistributedCtx, scenario: str, config: dict, cp_size: int, S: int) -> None: | ||
| B = 1 | ||
| WARMUP, NITERS = 25, 100 | ||
| HQ, HK = config["num_attention_heads"], config["num_key_value_heads"] | ||
| D = config["head_dim"] | ||
|
|
||
| cp_group = ctx.device_mesh.get_group("cp") | ||
| device = torch.cuda.current_device() | ||
| softmax_scale = D**-0.5 | ||
| S_local = S // cp_size | ||
|
|
||
| torch.manual_seed(42) | ||
| kwargs = dict(device=device, dtype=torch.bfloat16) | ||
| q = torch.randn(B, S_local, HQ, D, requires_grad=True, **kwargs) | ||
| k = torch.randn(B, S_local, HK, D, requires_grad=True, **kwargs) | ||
| v = torch.randn(B, S_local, HK, D, requires_grad=True, **kwargs) | ||
| grad_out = torch.randn(B, S_local, HQ, D, **kwargs) | ||
|
|
||
| def run_once() -> None: | ||
| q.grad, k.grad, v.grad = None, None, None | ||
| out = ring_attention_func(q, k, v, softmax_scale, cp_group) | ||
| out.backward(grad_out) | ||
|
|
||
| for _ in range(WARMUP): | ||
| run_once() | ||
| torch.cuda.synchronize() | ||
|
|
||
| # Timed forward/backward, separated by CUDA events. | ||
| fwd_total_ms = 0.0 | ||
| bwd_total_ms = 0.0 | ||
| for _ in range(NITERS): | ||
| q.grad, k.grad, v.grad = None, None, None | ||
| fwd_start = torch.cuda.Event(enable_timing=True) | ||
| fwd_end = torch.cuda.Event(enable_timing=True) | ||
| bwd_end = torch.cuda.Event(enable_timing=True) | ||
|
haok1402 marked this conversation as resolved.
|
||
| fwd_start.record() | ||
| out = ring_attention_func(q, k, v, softmax_scale, cp_group) | ||
| fwd_end.record() | ||
| out.backward(grad_out) | ||
| bwd_end.record() | ||
| torch.cuda.synchronize() | ||
| fwd_total_ms += fwd_start.elapsed_time(fwd_end) | ||
| bwd_total_ms += fwd_end.elapsed_time(bwd_end) | ||
|
|
||
| fwd_avg = fwd_total_ms / NITERS | ||
| bwd_avg = bwd_total_ms / NITERS | ||
|
|
||
| if ctx.rank == 0: | ||
| print(f"{scenario} | fwd: {fwd_avg:7.3f} ms , bwd: {bwd_avg:7.3f} ms", flush=True) | ||
| torch.distributed.barrier() | ||
|
|
||
| # Nsys profile capture with one iteration. | ||
| torch.cuda.synchronize() | ||
| torch.cuda.profiler.start() | ||
| run_once() | ||
| torch.cuda.synchronize() | ||
| torch.cuda.profiler.stop() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| scenario = sys.argv[1] | ||
|
haok1402 marked this conversation as resolved.
|
||
| config, cp_size, S = parse_scenario(scenario) | ||
|
|
||
| cfg = DistributedCfg() | ||
| cfg.context_parallel_size = cp_size | ||
| parent_cfg = SimpleNamespace(distributed=cfg) | ||
| parent_ctx = SimpleNamespace(distributed=DistributedCtx()) | ||
| with distributed_context(parent_cfg, parent_ctx) as ctx: | ||
| run(ctx, scenario, config, cp_size, S) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,44 @@ | ||
| #!/bin/bash | ||
| # Benchmark ring attention. Always captures an nsys profile of the final iteration. | ||
| # | ||
| # Usage: | ||
| # bash benchmarks/operators/bench_ring_attention.sh qwen3-30b-a3b-cp4-s32k | ||
|
|
||
| set -euo pipefail | ||
| export OMP_NUM_THREADS=8 | ||
|
|
||
| SCRIPT=benchmarks/operators/bench_ring_attention.py | ||
| OUTDIR=workspace/benchmarks/operators | ||
| mkdir -p $OUTDIR | ||
|
|
||
| if [ $# -lt 1 ]; then | ||
| echo "Usage: $0 <scenario> (e.g. qwen3-30b-a3b-cp4-s32k)" >&2 | ||
| exit 1 | ||
| fi | ||
| SCENARIO=$1 | ||
|
|
||
| # Extract cp_size from the scenario for torchrun's --nproc-per-node. | ||
| if [[ ! $SCENARIO =~ -cp([0-9]+)- ]]; then | ||
| echo "Scenario '$SCENARIO' missing -cp<N>- segment" >&2 | ||
| exit 1 | ||
| fi | ||
| NPROC=${BASH_REMATCH[1]} | ||
|
|
||
| NSYS_ARGS=() | ||
| NSYS_ARGS+=(profile) | ||
| NSYS_ARGS+=(--stats=false) | ||
| NSYS_ARGS+=(--trace=cuda,osrt,nvtx) | ||
| NSYS_ARGS+=(--force-overwrite=true) | ||
| NSYS_ARGS+=(--output=$OUTDIR/ring_attention.$SCENARIO) | ||
| NSYS_ARGS+=(--cuda-graph-trace=node) | ||
| NSYS_ARGS+=(--capture-range=cudaProfilerApi) | ||
| NSYS_ARGS+=(--capture-range-end=stop-shutdown) | ||
| NSYS_ARGS+=(--delay=0) | ||
|
|
||
| TORCHRUN_ARGS=() | ||
| TORCHRUN_ARGS+=(--nnodes=1) | ||
| TORCHRUN_ARGS+=(--nproc-per-node=$NPROC) | ||
| TORCHRUN_ARGS+=(--rdzv-backend=c10d) | ||
| TORCHRUN_ARGS+=(--rdzv-endpoint=localhost:15213) | ||
|
|
||
| nsys ${NSYS_ARGS[@]} torchrun ${TORCHRUN_ARGS[@]} $SCRIPT $SCENARIO | ||
|
haok1402 marked this conversation as resolved.
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
Empty file.
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,89 @@ | ||
| """Test the correctness of ring attention under context parallelism.""" | ||
|
|
||
| from dataclasses import dataclass, fields | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from pithtrain.modules.distributed import DistributedCfg, DistributedCtx | ||
| from pithtrain.operators.flash_attn_v4 import flash_attn_func | ||
| from pithtrain.operators.ring_attention import ring_attention_func | ||
| from tests.utilities import cosine_error, launch | ||
|
|
||
|
|
||
| @dataclass | ||
| class Request: | ||
| B: int | ||
| S: int | ||
| HQ: int | ||
| HK: int | ||
| D: int | ||
| atol: float = 1e-5 | ||
|
|
||
|
|
||
| @dataclass | ||
| class Result: | ||
| out: torch.Tensor | ||
| dq: torch.Tensor | ||
| dk: torch.Tensor | ||
| dv: torch.Tensor | ||
|
|
||
|
|
||
| def record(ctx: DistributedCtx, req: Request) -> tuple[Result, Result]: | ||
| """ | ||
| Record the forward output and the input gradients dQ, dK, dV for both the | ||
| baseline and the implementation. The baseline is flash_attn_func run on the | ||
| full sequence with no CP communication, effectively the full sequence at | ||
| once; the implementation is ring_attention_func run on this rank's sequence | ||
| shard with K/V exchanged around the CP ring. | ||
| """ | ||
| cp_group = ctx.device_mesh.get_group("cp") | ||
| cp_rank, cp_size = cp_group.rank(), cp_group.size() | ||
| device = torch.cuda.current_device() | ||
| softmax_scale = req.D**-0.5 | ||
| S_local = req.S // cp_size | ||
|
haok1402 marked this conversation as resolved.
|
||
|
|
||
| def shard(x: torch.Tensor) -> torch.Tensor: | ||
| return x[:, cp_rank * S_local : (cp_rank + 1) * S_local] | ||
|
|
||
| torch.manual_seed(42) | ||
| q_full = torch.randn(req.B, req.S, req.HQ, req.D, device=device, dtype=torch.bfloat16) | ||
| k_full = torch.randn(req.B, req.S, req.HK, req.D, device=device, dtype=torch.bfloat16) | ||
| v_full = torch.randn(req.B, req.S, req.HK, req.D, device=device, dtype=torch.bfloat16) | ||
|
|
||
| q_ref = q_full.clone().requires_grad_(True) | ||
| k_ref = k_full.clone().requires_grad_(True) | ||
| v_ref = v_full.clone().requires_grad_(True) | ||
| out_ref = flash_attn_func(q_ref, k_ref, v_ref, softmax_scale, causal=True) | ||
| out_ref.sum().backward() | ||
| ref = Result(shard(out_ref), shard(q_ref.grad), shard(k_ref.grad), shard(v_ref.grad)) | ||
|
|
||
| q_imp = shard(q_full).clone().requires_grad_(True) | ||
| k_imp = shard(k_full).clone().requires_grad_(True) | ||
| v_imp = shard(v_full).clone().requires_grad_(True) | ||
| out_imp = ring_attention_func(q_imp, k_imp, v_imp, softmax_scale, cp_group) | ||
| out_imp.sum().backward() | ||
| imp = Result(out_imp, q_imp.grad, k_imp.grad, v_imp.grad) | ||
|
|
||
| return ref, imp | ||
|
|
||
|
|
||
| def verify(ctx: DistributedCtx, req: Request) -> None: | ||
| ref, imp = record(ctx, req) | ||
| for f in fields(ref): | ||
| error = cosine_error(getattr(ref, f.name), getattr(imp, f.name)) | ||
| if error >= req.atol: | ||
| raise AssertionError(f"{f.name} diverged: {error=:.2e} >= {req.atol=}") | ||
|
|
||
|
|
||
| REQUESTS = [] | ||
| REQUESTS.append(pytest.param(2, Request(B=1, S=2048, HQ=4, HK=4, D=64), id="CP2-MHA-S2048")) | ||
| REQUESTS.append(pytest.param(2, Request(B=2, S=2048, HQ=8, HK=2, D=64), id="CP2-GQA-S2048")) | ||
| REQUESTS.append(pytest.param(4, Request(B=1, S=4096, HQ=12, HK=4, D=128), id="CP4-GQA-S4096")) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("cp_size,req", REQUESTS) | ||
| def test_ring_attention_vs_dense(cp_size: int, req: Request) -> None: | ||
| cfg = DistributedCfg() | ||
| cfg.context_parallel_size = cp_size | ||
| launch(cfg, verify, req) | ||
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.