
Stage-Separated Profiling with Warm KV Cache Support#9

Merged
TerrenceZhangX merged 13 commits into microsoft:main from TerrenceZhangX:zhangt/stage-profiling
Mar 17, 2026

Conversation

@TerrenceZhangX
Contributor

Summary

Add a complete stage-separated profiling pipeline (run_stage_profile.py) that captures prefill (EXTEND) and decode traces independently from SGLang inference servers. Includes cross-rank kernel aggregation, kernel shape merging, and comprehensive tests.

Motivation

1. Decouple prefill and decode to reduce profiling cost

The existing run_profile.py captures prefill+decode in one fused trace, so every (prefill_len, decode_len) combination needs its own run. For N prefill lengths × M decode lengths that's N×M profiling passes. Stage-separated profiling captures them independently — N+M runs instead of N×M — and traces can be freely recombined offline.

2. Warm KV cache support for agentic / long-context workloads

In multi-turn chat, tool-use agents, and RAG pipelines, each new request builds on thousands of tokens already in KV cache. The prefill phase only processes new tokens while reading from cached context, producing different kernel shapes and timings than a cold prefill of the same total length. The --existing-ctx flag simulates this by seeding the radix cache before profiling, enabling accurate traces for warm-cache scenarios.

Implementation

This PR implements the full pipeline:

  • Separates prefill (EXTEND) from decode traces using SGLang's profile_by_stage API
  • Controls exact token counts via input_ids (not approximate text prompts)
  • Supports KV cache context simulation (--existing-ctx) for warm prefill scenarios
  • Aggregates kernel timings across TP/DP ranks with correct semantics per kernel type

Changes

New files

| File | Lines | Purpose |
| --- | --- | --- |
| `scripts/run_stage_profile.py` | 925 | Main orchestrator: warmup → profile → parse → cross-rank analysis, with `--collect {perf,shapes,all}` modes |
| `utils/cross_rank_agg.py` | 491 | Cross-rank kernel aggregation (symmetric collectives → per-invocation min, asymmetric → max, compute → mean) |
| `utils/shape_merge.py` | 365 | Merge kernel shape data from the no-CUDA-graph pass into timing CSVs with strict 1:1 occurrence matching |
| `utils/net.py` | 31 | Shared `wait_for_port` helper (extracted from duplicate definitions) |
| `utils/__init__.py` | 0 | Package marker |
| `tests/integration/test_stage_profile_configs.py` | 518 | Integration tests for all three `--collect` modes with shape validation |
| `tests/unit/test_cross_rank_agg.py` | 248 | Unit tests for kernel classification, comm detection, multi-rank aggregation |
| `tests/unit/test_shape_merge.py` | 322 | Unit tests for shape merging (single-pair and directory modes) |

Modified files

| File | What changed |
| --- | --- |
| `README.md` | +109 lines: Stage Profiling section with quick reference, examples, output structure, utilities table |
| `.gitignore` | Add `__pycache__/` |
| `simulator/base_parser.py` | None-safe dtype/byte lookups in three comm kernel blocks: `pytorch_to_nccl_dtype.get()` returns None for unknown dtypes, which previously caused a TypeError |
| `simulator/benchmarks/nccl_benchmarks.py` | try/except ValueError around `float()` parsing in `run_nccl_all_reduce_perf` and `run_nccl_all_gather_perf` |
| `scripts/run_simulate.py` | Formatting only (black) |
| `tests/unit/test_batch_request.py` | Add new fields (`text_prompt_len`, `vision_prompt_len`, `timestamp`), mock `requests.get` to prevent real HTTP calls, fix `base_url` |
| `tests/unit/test_kernel_db_coverage.py` | Relax missing-op assertion to `warnings.warn`; unknown kernels need manual `kernels.json` entries, not test failures |
| `tests/unit/test_defined_len.py` | Formatting only (black) |
| `tests/unit/test_llmcompass_backend.py` | Remove an extra blank line |

Key design decisions

Single-point API (not sweep grid)

The script profiles a single (bs, input_len, existing_ctx, decode_tokens) point per invocation. Grid sweeps are handled by the caller (shell loops, CI matrix, etc.). This simplifies the code and makes each run self-contained.
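A caller-side sweep can be sketched as a loop that expands a grid into one single-point invocation per combination. This is a minimal sketch: the flag names come from this PR, but the helper name and grid values are illustrative.

```python
# Illustrative caller-side sweep over single-point run_stage_profile.py runs.
# Flag names (--bs/--input-len/--existing-ctx/--decode-tokens) match the PR;
# build_sweep_commands and the grids below are hypothetical.
import itertools

def build_sweep_commands(bs_grid, input_len_grid, ctx_grid, decode_tokens=8, port=30001):
    """Expand a (bs, input_len, existing_ctx) grid into one command per point."""
    cmds = []
    for bs, input_len, ctx in itertools.product(bs_grid, input_len_grid, ctx_grid):
        cmds.append(
            "python scripts/run_stage_profile.py "
            f"--port {port} --bs {bs} --input-len {input_len} "
            f"--existing-ctx {ctx} --decode-tokens {decode_tokens}"
        )
    return cmds

commands = build_sweep_commands([1, 8], [256, 2048], [0, 4096])
print(len(commands))  # 2 * 2 * 2 = 8 points
```

Because each run is self-contained, the same expansion works as a shell loop or a CI matrix.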

Exact token control via input_ids

Prompts use raw token IDs (input_ids) instead of text to ensure exact prefill token counts. Fixed random seeds guarantee deterministic prompts for radix-cache prefix matching between seed and profile requests.
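The idea can be sketched as follows, assuming a hypothetical `make_input_ids` helper and an illustrative vocabulary size (the real prompt generation lives in `run_stage_profile.py`):

```python
# Sketch: deterministic token-ID prompts so the seed request and the profile
# request share an identical radix-cache prefix. VOCAB_SIZE and the helper
# name are illustrative assumptions.
import random

VOCAB_SIZE = 32000  # illustrative; the real value comes from the model config

def make_input_ids(n_tokens, seed):
    rng = random.Random(seed)  # fixed seed => identical IDs on every call
    return [rng.randrange(VOCAB_SIZE) for _ in range(n_tokens)]

# Seed request sends existing_ctx tokens; the profile request re-sends the
# same prefix plus input_len new tokens, so only the extension is prefilled.
existing_ctx, input_len = 4096, 256
seed_ids = make_input_ids(existing_ctx, seed=0)
profile_ids = seed_ids + make_input_ids(input_len, seed=1)
assert profile_ids[:existing_ctx] == seed_ids  # prefix hit in the radix cache
```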

Fencepost handling

SGLang's profiler stops when batch_count > num_steps (not >=). To capture exactly decode_tokens decode batches, we pass num_steps = max(1, decode_tokens - 1).
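The fencepost rule can be checked against a toy model of the stop condition. The loop below only mimics the `batch_count > num_steps` check described in this PR; it is not SGLang's profiler code.

```python
# num_steps formula from the PR, plus a toy loop mimicking the profiler's
# strictly-greater stop condition.
def num_steps_for(decode_tokens):
    return max(1, decode_tokens - 1)

def captured_batches(num_steps):
    batch_count = captured = 0
    while True:
        captured += 1                # this decode batch is traced
        batch_count += 1
        if batch_count > num_steps:  # stop only once count EXCEEDS num_steps
            break
    return captured

# With num_steps = decode_tokens - 1, exactly decode_tokens batches are traced.
assert captured_batches(num_steps_for(4)) == 4
# decode_tokens=1 would still capture 2 batches (num_steps floors at 1),
# which is why the PR enforces --decode-tokens >= 2.
```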

Cross-rank aggregation semantics

| Kernel type | Aggregation | Rationale |
| --- | --- | --- |
| Symmetric collectives (all_reduce, all_gather, reduce_scatter) | Per-invocation min across ranks | All ranks participate equally; min removes straggler noise |
| Asymmetric collectives (all_to_all) | Max across ranks | Different ranks may send/receive different amounts |
| Compute kernels | Mean across ranks | Slight per-rank timing variation is noise |
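These semantics can be sketched in a few lines. The names below are illustrative, not the actual `utils/cross_rank_agg.py` API; the symmetric branch also shows the min-rank-total fallback used when call counts disagree across ranks, as described in the commit history.

```python
# Sketch of the per-kernel aggregation semantics (hypothetical function name).
# Input: one list of per-invocation durations (us) per rank, for one kernel.
from statistics import mean

def aggregate_kernel(per_rank_durations, kind):
    counts = {len(d) for d in per_rank_durations}
    if kind == "sym_comm":
        if len(counts) == 1:
            # Per-invocation min across ranks, then sum: drops straggler noise
            # even when the "fast rank" rotates between invocations.
            return sum(min(vals) for vals in zip(*per_rank_durations))
        # Fallback when call counts disagree across ranks: min of rank totals.
        return min(sum(d) for d in per_rank_durations)
    if kind == "asym_comm":
        return max(sum(d) for d in per_rank_durations)
    return mean(sum(d) for d in per_rank_durations)  # compute kernels
```

With a rotating fast rank, e.g. rank0 = [4, 850] and rank1 = [850, 4], per-invocation min gives 8 us, while min-rank-total would report 854 us.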

Shape merge

CUDA graphs hide kernel input shapes. A second pass without CUDA graphs captures shapes, which are merged into timing CSVs by 1:1 kernel-name occurrence matching. Mismatches raise ValueError to catch configuration bugs early.
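A minimal sketch of the 1:1 occurrence matching (illustrative, not the real `utils/shape_merge.py` code): the i-th occurrence of a kernel name in the timing rows receives the shape recorded for the i-th occurrence of the same name in the shape-pass rows.

```python
# Sketch of strict occurrence matching; row schemas here are simplified.
from collections import defaultdict

def merge_shapes(timing_rows, shape_rows):
    shapes_by_name = defaultdict(list)
    for row in shape_rows:
        shapes_by_name[row["name"]].append(row["dims"])
    seen = defaultdict(int)
    merged = []
    for row in timing_rows:
        queue = shapes_by_name[row["name"]]
        i = seen[row["name"]]
        if i >= len(queue):
            raise ValueError(f"occurrence mismatch for {row['name']!r}")
        merged.append({**row, "dims": queue[i]})
        seen[row["name"]] += 1
    # Leftover shape occurrences also indicate a configuration bug.
    for name, queue in shapes_by_name.items():
        if seen[name] != len(queue):
            raise ValueError(f"unmatched shape occurrences for {name!r}")
    return merged
```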

Testing

  • Unit tests: test_cross_rank_agg.py (13 tests), test_shape_merge.py (10 tests) — run locally with pytest
  • Integration tests: test_stage_profile_configs.py (4 tests) — require GPU + Docker container with model configs. Filter with RUN_CONFIGS=P1
```bash
# Unit tests (no GPU needed)
cd FlowSim && python -m pytest tests/unit/test_cross_rank_agg.py tests/unit/test_shape_merge.py -v

# Integration tests (inside Docker with GPUs)
RUN_CONFIGS=P1 python -m pytest tests/integration/test_stage_profile_configs.py -v --timeout=1800
```

Stats

17 files changed, 3173 insertions(+), 84 deletions(-)

Tong Zhang and others added 8 commits March 3, 2026 07:55
Add scripts/run_stage_profile.py — uses SGLang native profile_by_stage API
(with num_steps=1) to automatically collect separate EXTEND (prefill) and
DECODE traces. No sglang patches required.

Features:
- Warmup requests before profiling to avoid CUDA graph capture overhead
- Single-point or sweep mode over (batch_size, context_len) grid
- Optional server launch (all-in-one mode)
- Optional trace parsing via run_parse.py
- Organized output: <output_dir>/bs<N>_ctx<N>/{EXTEND,DECODE}.trace.json.gz

Usage:
  python scripts/run_stage_profile.py --port 30001 --bs 1 --ctx 2048
  python scripts/run_stage_profile.py --port 30001 --sweep
- Changed send_requests to fire all bs>1 requests concurrently via
  threading.Thread daemon threads. Previously requests were sequential,
  causing num_steps=1 to only profile the first request (all bs>1 points
  were effectively bs=1).
- Added stabilization wait after wait_for_traces returns to allow slower
  DP workers to export their traces (10s file-count-stable window).
- Added kernel classification and analyze_traces/print_analysis for
  automated breakdown of compute vs communication costs.
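The concurrent-fire pattern above can be sketched as follows; this is illustrative, since the real code posts HTTP requests to the SGLang server, while here `send_fn` is a stand-in.

```python
# Sketch: fire all bs>1 requests at once via daemon threads so they land in
# the same batch, rather than being profiled one at a time.
import threading

def fire_concurrently(send_fn, payloads):
    threads = [
        threading.Thread(target=send_fn, args=(p,), daemon=True) for p in payloads
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()  # wait for every response before stopping the profiler
```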
…black formatting

- scripts/run_stage_profile.py: add --collect {perf,shapes,all} modes,
  integrate collect_shapes + merge_shapes, remove legacy flags
- scripts/run_configs.sh: unified bash wrapper for all 4 parallelism configs
  (P1-P4) with perf/shapes/all/reanalyze modes
- utils/cross_rank_agg.py: cross-rank kernel aggregation
  (sym-min / asym-max comm, mean compute)
- utils/shape_merge.py: merge kernel shape data into timing CSVs
- utils/__init__.py: package marker
- README.md: add Stage Profiling docs, helper scripts table, utils table
- Apply black formatting to all Python files
SGLang's cross_device_reduce_1stage assigns one 'fast rank' (~4us) per
invocation while others barrier-wait (~850us). Under profiling noise the
fast-rank role rotates, so no single rank is consistently fast. The old
min-rank-total approach still picked up inflated values (29-78ms vs
normal ~10ms) in P1/P2 EXTEND data.

Fix: for each invocation index, take min across ranks, then sum.
Falls back to min-rank-total when call counts disagree across ranks.

Result: all 240/240 data points clean, P1/P2 EXTEND anomalies resolved.
Add --collect prefill mode to run_stage_profile.py for profiling EXTEND
(prefill) over a (input_len, existing_ctx) grid, matching the workload
characterization README spec (20-point prefill matrix).

Protocol for each (input_len, existing_ctx) point:
1. POST /flush_cache — clear radix + KV cache
2. Seed request with existing_ctx tokens (populates radix cache)
3. POST /start_profile — arm profiler
4. Profile request with existing_ctx + input_len tokens (prefix hits
   cache, extension triggers real EXTEND of input_len new tokens)
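The four-step protocol can be sketched as a sequence of (method, endpoint, payload) tuples. The `/flush_cache` and `/start_profile` paths are the ones named above; the `/generate` payload fields and the `profiling_steps` helper are illustrative assumptions about SGLang's HTTP interface.

```python
# Sketch of the per-point warm-prefill protocol; payloads are simplified and
# seed_ids/new_ids stand in for the deterministic input_ids prompts.
def profiling_steps(seed_ids, new_ids):
    return [
        ("POST", "/flush_cache", None),                  # 1. clear radix + KV cache
        ("POST", "/generate", {"input_ids": seed_ids}),  # 2. seed the radix cache
        ("POST", "/start_profile", None),                # 3. arm the profiler
        # 4. prefix hits cache; only len(new_ids) tokens trigger a real EXTEND
        ("POST", "/generate", {"input_ids": seed_ids + new_ids}),
    ]
```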

Key design choices:
- Use input_ids (raw token ID arrays) instead of text prompts to get
  exact token counts with zero tokenizer error
- PREFIX_TOKEN=1000 for cached prefix, EXTEND_TOKEN=2000 for new part
- /flush_cache between points ensures no cross-contamination

Validated on TP=4 Qwen3-235B-A22B-FP8:
- input=256, ctx=0: 18ms EXTEND, AR shape=[256, 4096] (exact)
- input=256, ctx=4096: 23ms EXTEND (5ms more from longer KV attention)
- input=2048, ctx=0: 78ms EXTEND
- input=2048, ctx=4096: 91ms EXTEND

Also updates run_configs.sh with INPUT_LEN_GRID and EXISTING_CTX_GRID.
…fill

Prefill sweep is now a 3D grid (bs, input_len, existing_ctx):
- bs: auto-stops when OOM is detected at a given (input_len, ctx)
- Default prefill_bs_grid: 1,2,4,8,16,32,64,128,256,512
- Default input_len_grid: 256,512,1024,2048,4096,8192
- Default existing_ctx_grid: 0,4096,8192,16384,24576

collect_one_prefill now fires bs concurrent requests to measure
real multi-request prefill performance (GEMM shape = [bs*input_len, hidden]).

OOM detection:
- If a request returns OOM error, all larger bs values for the same
  (input_len, existing_ctx) pair are automatically skipped
- After OOM, flush_cache is called to recover
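The OOM auto-stop can be sketched as follows (illustrative names; `try_point` stands in for one profiling attempt at a fixed (input_len, existing_ctx) pair):

```python
# Sketch: stop sweeping batch sizes at the first OOM, since larger bs values
# for the same (input_len, existing_ctx) pair can only OOM again.
def sweep_bs(bs_grid, try_point):
    """try_point(bs) returns True on success, False on OOM."""
    results = {}
    for bs in sorted(bs_grid):
        if not try_point(bs):
            results[bs] = "oom"
            break  # skip all larger bs values for this point
        results[bs] = "ok"
    return results
```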

--disable-chunked-prefill flag:
- Adds --chunked-prefill-size 999999 to server opts
- Allows profiling raw kernel saturation without chunk boundaries
- Reveals the natural GEMM saturation point for choosing optimal chunk size

run_configs.sh updated with PREFILL_BS_GRID and --disable-chunked-prefill.
… validation, tests

- Replace sweep grid (bs-grid/ctx-grid) with single-point --bs/--input-len/--existing-ctx/--decode-tokens API
- Remove dead code: seed_prefix, send_requests, collect_one, _run_prefill, unused imports
- Fix file handle leak in launch_server/kill_server
- Fix attention seqlen formula (per-sequence, not total)
- Add strict shape merge with occurrence mismatch errors and op/operation column copying
- Add None-safe dtype/byte lookups in base_parser.py comm kernel blocks
- Add ValueError handling for float parsing in nccl_benchmarks.py
- Extract wait_for_port to utils/net.py, remove duplicate from run_stage_profile.py
- Relax test_kernel_db_coverage to warn instead of fail on missing op mappings
- Update test_batch_request.py for new fields, mock HTTP calls, base_url fix
- Update README to match single-point API and new output directory format
- Delete obsolete scripts/run_configs.sh
- Add unit tests for cross_rank_agg and shape_merge
- Add integration tests for --collect {perf,shapes,all} modes


- README: fix column name 'Input Dims' → 'Dims' to match actual CSV schema
- shape_merge: handle empty timing CSV (guard against IndexError)
- shape_merge: warn on duplicate shape CSVs for same (rank, stage)
- shape_merge: fix docstring imports to utils.shape_merge
- cross_rank_agg: fix docstring imports to utils.cross_rank_agg
- run_stage_profile: warmup() use input_ids instead of text prompt
- run_stage_profile: parse_traces() check subprocess returncode



Terrence Zhang added 2 commits March 17, 2026 02:49
- cross_rank_agg: add reducescatter/reduce_scatter to _COMM_KEYWORDS
- cross_rank_agg: fix CLI docstring to use python -m utils.cross_rank_agg
- run_stage_profile: use project root on sys.path, import via utils.*
- run_stage_profile: collect_shapes checks both EXTEND+DECODE before skip
- run_stage_profile: enforce --decode-tokens >= 2 (fencepost guard)


- shape_merge: validate shape CSV has no extra unmatched occurrences
- shape_merge: add extrasaction='ignore' to DictWriter for extra columns
- run_stage_profile: return ok=False for non-OOM profile request failures
Contributor

Copilot AI left a comment


Pull request overview

Adds a stage-separated profiling workflow to FlowSim that profiles SGLang EXTEND (prefill) and DECODE independently (including warm KV-cache scenarios), then post-processes traces into parsed CSVs with cross-rank aggregation and optional kernel-shape merging.

Changes:

  • Introduces scripts/run_stage_profile.py to orchestrate stage profiling, parsing, cross-rank analysis, and optional shape collection/merging.
  • Adds utilities for cross-rank kernel aggregation (utils/cross_rank_agg.py), shape merging (utils/shape_merge.py), and a shared networking helper (utils/net.py).
  • Expands unit/integration test coverage for the new pipeline and adjusts a few existing tests/parsers for robustness.

Reviewed changes

Copilot reviewed 15 out of 17 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| `scripts/run_stage_profile.py` | New stage-separated profiling orchestrator (collect/parse/analyze/shapes/merge). |
| `utils/cross_rank_agg.py` | New cross-rank aggregation logic with comm-vs-compute semantics. |
| `utils/shape_merge.py` | New shape-pass → timing-pass CSV merge utility with occurrence matching. |
| `utils/net.py` | New shared `wait_for_port()` helper. |
| `utils/__init__.py` | Marks utils as a package. |
| `tests/integration/test_stage_profile_configs.py` | Integration coverage for `--collect {perf,shapes,all}` and workload-shape validation. |
| `tests/unit/test_cross_rank_agg.py` | Unit tests for kernel classification + aggregation behavior. |
| `tests/unit/test_shape_merge.py` | Unit tests for single-pair and directory shape merge behavior. |
| `simulator/base_parser.py` | Makes comm calibration more robust to unknown dtypes/bytes. |
| `simulator/benchmarks/nccl_benchmarks.py` | Hardens parsing of benchmark output by skipping non-float values. |
| `tests/unit/test_batch_request.py` | Updates request fields and prevents real HTTP calls in unit tests. |
| `tests/unit/test_kernel_db_coverage.py` | Downgrades missing-op assertion to a warning. |
| `tests/unit/test_defined_len.py` | Formatting-only change. |
| `tests/unit/test_llmcompass_backend.py` | Removes an extraneous blank line. |
| `scripts/run_simulate.py` | Formatting-only change (black). |
| `README.md` | Documents stage profiling usage, modes, outputs, and utilities. |
| `.gitignore` | Ignores `__pycache__/`. |


Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
@TerrenceZhangX TerrenceZhangX merged commit 871d423 into microsoft:main Mar 17, 2026
3 checks passed