Stage-Separated Profiling with Warm KV Cache Support #9
Merged
TerrenceZhangX merged 13 commits into microsoft:main on Mar 17, 2026
Conversation
Add scripts/run_stage_profile.py — uses SGLang native profile_by_stage API
(with num_steps=1) to automatically collect separate EXTEND (prefill) and
DECODE traces. No sglang patches required.
Features:
- Warmup requests before profiling to avoid CUDA graph capture overhead
- Single-point or sweep mode over (batch_size, context_len) grid
- Optional server launch (all-in-one mode)
- Optional trace parsing via run_parse.py
- Organized output: <output_dir>/bs<N>_ctx<N>/{EXTEND,DECODE}.trace.json.gz
Usage:
python scripts/run_stage_profile.py --port 30001 --bs 1 --ctx 2048
python scripts/run_stage_profile.py --port 30001 --sweep
- Changed send_requests to fire all bs>1 requests concurrently via threading.Thread daemon threads. Previously requests were sequential, so num_steps=1 only profiled the first request (all bs>1 points were effectively bs=1).
- Added a stabilization wait after wait_for_traces returns to allow slower DP workers to export their traces (10s file-count-stable window).
- Added kernel classification and analyze_traces/print_analysis for an automated breakdown of compute vs. communication costs.
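The concurrent-fire pattern described above can be sketched generically. This is a minimal illustration, not the PR's actual code; fire_concurrently and send_fn are illustrative names (the real send_fn would POST to the server's /generate endpoint):

```python
import threading

def fire_concurrently(send_fn, payloads):
    """Run send_fn(payload) for every payload on its own daemon thread,
    so the scheduler sees all bs requests in the same step rather than
    one at a time. Results are collected in submission order."""
    results = [None] * len(payloads)

    def worker(i, p):
        results[i] = send_fn(p)

    threads = [
        threading.Thread(target=worker, args=(i, p), daemon=True)
        for i, p in enumerate(payloads)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results
```

With sequential sends, a num_steps=1 profile window closes after the first request's batch; firing all threads before joining is what lets a single profiled step contain the full bs-sized batch.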
…black formatting
- scripts/run_stage_profile.py: add --collect {perf,shapes,all} modes,
integrate collect_shapes + merge_shapes, remove legacy flags
- scripts/run_configs.sh: unified bash wrapper for all 4 parallelism configs
(P1-P4) with perf/shapes/all/reanalyze modes
- utils/cross_rank_agg.py: cross-rank kernel aggregation
(sym-min / asym-max comm, mean compute)
- utils/shape_merge.py: merge kernel shape data into timing CSVs
- utils/__init__.py: package marker
- README.md: add Stage Profiling docs, helper scripts table, utils table
- Apply black formatting to all Python files
SGLang's cross_device_reduce_1stage assigns one "fast rank" (~4us) per invocation while the others barrier-wait (~850us). Under profiling noise the fast-rank role rotates, so no single rank is consistently fast. The old min-rank-total approach still picked up inflated values (29-78ms vs. the normal ~10ms) in P1/P2 EXTEND data.
Fix: for each invocation index, take the min across ranks, then sum. Falls back to min-rank-total when call counts disagree across ranks.
Result: all 240/240 data points clean; P1/P2 EXTEND anomalies resolved.
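The fix above can be sketched in a few lines; aggregate_comm_kernel is an illustrative name, not the PR's actual API. Each inner list holds one rank's kernel durations (us), ordered by invocation index:

```python
def aggregate_comm_kernel(per_rank_durations):
    """Sum of per-invocation minima across ranks. Falls back to the
    minimum per-rank total when call counts disagree across ranks."""
    counts = {len(d) for d in per_rank_durations}
    if len(counts) != 1:
        # Call counts disagree: can't align invocations, use min-rank-total.
        return min(sum(d) for d in per_rank_durations)
    # For each invocation index, keep the fastest rank's time, then sum.
    # This follows the rotating fast-rank role instead of betting on one rank.
    return sum(min(col) for col in zip(*per_rank_durations))
```

For two ranks alternating fast/slow roles, e.g. [[4, 850], [850, 4]], the per-invocation min picks 4us both times (total 8us), whereas either rank's own total would be ~854us.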
Add --collect prefill mode to run_stage_profile.py for profiling EXTEND (prefill) over a (input_len, existing_ctx) grid, matching the workload characterization README spec (20-point prefill matrix).
Protocol for each (input_len, existing_ctx) point:
1. POST /flush_cache — clear radix + KV cache
2. Seed request with existing_ctx tokens (populates radix cache)
3. POST /start_profile — arm profiler
4. Profile request with existing_ctx + input_len tokens (the prefix hits the cache; the extension triggers a real EXTEND of input_len new tokens)
Key design choices:
- Use input_ids (raw token ID arrays) instead of text prompts to get exact token counts with zero tokenizer error
- PREFIX_TOKEN=1000 for the cached prefix, EXTEND_TOKEN=2000 for the new part
- /flush_cache between points ensures no cross-contamination
Validated on TP=4 Qwen3-235B-A22B-FP8:
- input=256, ctx=0: 18ms EXTEND, AR shape=[256, 4096] (exact)
- input=256, ctx=4096: 23ms EXTEND (5ms more from longer KV attention)
- input=2048, ctx=0: 78ms EXTEND
- input=2048, ctx=4096: 91ms EXTEND
Also updates run_configs.sh with INPUT_LEN_GRID and EXISTING_CTX_GRID.
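The four-step protocol above can be expressed as the HTTP call sequence for one point. This is a sketch, not the PR's code: plan_point is a hypothetical helper that only builds the request plan (actually sending the requests is left out), and the /generate payload shape is an assumption:

```python
# Token-ID conventions from the commit message: one fixed ID for the cached
# prefix, another for the new (extended) part.
PREFIX_TOKEN, EXTEND_TOKEN = 1000, 2000

def plan_point(input_len: int, existing_ctx: int):
    """Return the (method, path, payload) sequence for one
    (input_len, existing_ctx) warm-prefill profiling point."""
    prefix = [PREFIX_TOKEN] * existing_ctx
    extension = [EXTEND_TOKEN] * input_len
    steps = [("POST", "/flush_cache", None)]  # 1. clear radix + KV cache
    if existing_ctx:
        # 2. seed request populates the radix cache with the prefix
        steps.append(("POST", "/generate", {"input_ids": prefix}))
    steps.append(("POST", "/start_profile", None))  # 3. arm profiler
    # 4. the prefix hits the cache; only input_len tokens run a real EXTEND
    steps.append(("POST", "/generate", {"input_ids": prefix + extension}))
    return steps
```

Using raw input_ids rather than text keeps the token counts exact, so the profiled EXTEND is precisely input_len tokens long.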
…fill
Prefill sweep is now a 3D grid (bs, input_len, existing_ctx):
- bs: auto-stops when OOM is detected at a given (input_len, ctx)
- Default prefill_bs_grid: 1,2,4,8,16,32,64,128,256,512
- Default input_len_grid: 256,512,1024,2048,4096,8192
- Default existing_ctx_grid: 0,4096,8192,16384,24576
collect_one_prefill now fires bs concurrent requests to measure real multi-request prefill performance (GEMM shape = [bs*input_len, hidden]).
OOM detection:
- If a request returns an OOM error, all larger bs values for the same (input_len, existing_ctx) pair are automatically skipped
- After OOM, flush_cache is called to recover
--disable-chunked-prefill flag:
- Adds --chunked-prefill-size 999999 to server opts
- Allows profiling raw kernel saturation without chunk boundaries
- Reveals the natural GEMM saturation point for choosing the optimal chunk size
run_configs.sh updated with PREFILL_BS_GRID and --disable-chunked-prefill.
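The OOM auto-skip can be sketched as a nested loop with an early break per (input_len, ctx) pair. sweep_prefill and run_point are illustrative names; run_point stands in for the real profiling call and returns False on OOM:

```python
def sweep_prefill(bs_grid, input_len_grid, ctx_grid, run_point):
    """Walk the 3D (bs, input_len, existing_ctx) grid in ascending bs order.
    Once a point OOMs, skip every larger bs for that (input_len, ctx) pair,
    since a bigger batch would only need more memory."""
    results = {}
    for input_len in input_len_grid:
        for ctx in ctx_grid:
            for bs in sorted(bs_grid):
                ok = run_point(bs, input_len, ctx)
                results[(bs, input_len, ctx)] = ok
                if not ok:
                    # OOM: abandon the rest of the bs axis for this pair
                    # (the real script also calls flush_cache to recover).
                    break
    return results
```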
… validation, tests
- Replace sweep grid (bs-grid/ctx-grid) with single-point --bs/--input-len/--existing-ctx/--decode-tokens API
- Remove dead code: seed_prefix, send_requests, collect_one, _run_prefill, unused imports
- Fix file handle leak in launch_server/kill_server
- Fix attention seqlen formula (per-sequence, not total)
- Add strict shape merge with occurrence mismatch errors and op/operation column copying
- Add None-safe dtype/byte lookups in base_parser.py comm kernel blocks
- Add ValueError handling for float parsing in nccl_benchmarks.py
- Extract wait_for_port to utils/net.py, remove duplicate from run_stage_profile.py
- Relax test_kernel_db_coverage to warn instead of fail on missing op mappings
- Update test_batch_request.py for new fields, mock HTTP calls, base_url fix
- Update README to match single-point API and new output directory format
- Delete obsolete scripts/run_configs.sh
- Add unit tests for cross_rank_agg and shape_merge
- Add integration tests for --collect {perf,shapes,all} modes
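The wait_for_port helper extracted to utils/net.py above plausibly looks like a simple TCP polling loop. This is a sketch under that assumption; the PR's actual signature and defaults may differ:

```python
import socket
import time

def wait_for_port(host: str, port: int, timeout: float = 60.0,
                  interval: float = 0.5) -> bool:
    """Poll until a TCP connection to (host, port) succeeds, returning True,
    or return False once the timeout expires."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with socket.create_connection((host, port), timeout=interval):
                return True
        except OSError:
            # Server not up yet (connection refused / timed out): retry.
            time.sleep(interval)
    return False
```

Extracting this into one module removes the duplicate definitions and gives both the server-launch path and the tests a single behavior to rely on.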
- README: fix column name 'Input Dims' → 'Dims' to match the actual CSV schema
- shape_merge: handle empty timing CSV (guard against IndexError)
- shape_merge: warn on duplicate shape CSVs for the same (rank, stage)
- shape_merge: fix docstring imports to utils.shape_merge
- cross_rank_agg: fix docstring imports to utils.cross_rank_agg
- run_stage_profile: warmup() uses input_ids instead of a text prompt
- run_stage_profile: parse_traces() checks the subprocess returncode
added 2 commits on March 17, 2026 02:49
- cross_rank_agg: add reducescatter/reduce_scatter to _COMM_KEYWORDS
- cross_rank_agg: fix CLI docstring to use python -m utils.cross_rank_agg
- run_stage_profile: use project root on sys.path, import via utils.*
- run_stage_profile: collect_shapes checks both EXTEND+DECODE before skipping
- run_stage_profile: enforce --decode-tokens >= 2 (fencepost guard)
- shape_merge: validate that the shape CSV has no extra unmatched occurrences
- shape_merge: add extrasaction='ignore' to DictWriter for extra columns
- run_stage_profile: return ok=False for non-OOM profile request failures
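The DictWriter change above leans on a standard-library behavior worth illustrating: with extrasaction='ignore', rows may carry columns beyond the declared header without raising ValueError. The field names here are illustrative:

```python
import csv
import io

buf = io.StringIO()
writer = csv.DictWriter(buf, fieldnames=["kernel", "dur_us"],
                        extrasaction="ignore")
writer.writeheader()
# 'shape' is not in fieldnames; with extrasaction='ignore' it is silently
# dropped instead of raising ValueError (the default, extrasaction='raise').
writer.writerow({"kernel": "gemm", "dur_us": 12.5, "shape": "[256,4096]"})
```

This lets the merge step pass enriched row dicts straight to the writer even when only a subset of columns belongs in the output CSV.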
Pull request overview
Adds a stage-separated profiling workflow to FlowSim that profiles SGLang EXTEND (prefill) and DECODE independently (including warm KV-cache scenarios), then post-processes traces into parsed CSVs with cross-rank aggregation and optional kernel-shape merging.
Changes:
- Introduces scripts/run_stage_profile.py to orchestrate stage profiling, parsing, cross-rank analysis, and optional shape collection/merging.
- Adds utilities for cross-rank kernel aggregation (utils/cross_rank_agg.py), shape merging (utils/shape_merge.py), and a shared networking helper (utils/net.py).
- Expands unit/integration test coverage for the new pipeline and adjusts a few existing tests/parsers for robustness.
Reviewed changes
Copilot reviewed 15 out of 17 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| scripts/run_stage_profile.py | New stage-separated profiling orchestrator (collect/parse/analyze/shapes/merge). |
| utils/cross_rank_agg.py | New cross-rank aggregation logic with comm-vs-compute semantics. |
| utils/shape_merge.py | New shape-pass → timing-pass CSV merge utility with occurrence matching. |
| utils/net.py | New shared wait_for_port() helper. |
| utils/__init__.py | Marks utils as a package. |
| tests/integration/test_stage_profile_configs.py | Integration coverage for --collect {perf,shapes,all} and workload-shape validation. |
| tests/unit/test_cross_rank_agg.py | Unit tests for kernel classification + aggregation behavior. |
| tests/unit/test_shape_merge.py | Unit tests for single-pair and directory shape merge behavior. |
| simulator/base_parser.py | Makes comm calibration more robust to unknown dtypes/bytes. |
| simulator/benchmarks/nccl_benchmarks.py | Hardens parsing of benchmark output by skipping non-float values. |
| tests/unit/test_batch_request.py | Updates request fields and prevents real HTTP calls in unit tests. |
| tests/unit/test_kernel_db_coverage.py | Downgrades missing-op assertion to a warning. |
| tests/unit/test_defined_len.py | Formatting-only change. |
| tests/unit/test_llmcompass_backend.py | Removes an extraneous blank line. |
| scripts/run_simulate.py | Formatting-only change (black). |
| README.md | Documents stage profiling usage, modes, outputs, and utilities. |
| .gitignore | Ignores __pycache__/. |
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Summary
Add a complete stage-separated profiling pipeline (run_stage_profile.py) that captures prefill (EXTEND) and decode traces independently from SGLang inference servers. Includes cross-rank kernel aggregation, kernel shape merging, and comprehensive tests.
Motivation
1. Decouple prefill and decode to reduce profiling cost
The existing run_profile.py captures prefill+decode in one fused trace, so every (prefill_len, decode_len) combination needs its own run. For N prefill lengths × M decode lengths that's N×M profiling passes. Stage-separated profiling captures them independently — N+M runs instead of N×M — and traces can be freely recombined offline.
2. Warm KV cache support for agentic / long-context workloads
In multi-turn chat, tool-use agents, and RAG pipelines, each new request builds on thousands of tokens already in KV cache. The prefill phase only processes new tokens while reading from cached context, producing different kernel shapes and timings than a cold prefill of the same total length. The --existing-ctx flag simulates this by seeding the radix cache before profiling, enabling accurate traces for warm-cache scenarios.
Implementation
This PR implements the full pipeline:
- SGLang native profile_by_stage API for stage-separated trace capture
- Exact token control via input_ids (not approximate text prompts)
- Warm KV cache seeding (--existing-ctx) for warm prefill scenarios
Changes
New files
- scripts/run_stage_profile.py: profiling orchestrator with --collect {perf,shapes,all} modes
- utils/cross_rank_agg.py: cross-rank kernel aggregation
- utils/shape_merge.py: shape-to-timing CSV merge
- utils/net.py: shared wait_for_port helper (extracted from duplicate definitions)
- utils/__init__.py: package marker
- tests/integration/test_stage_profile_configs.py: integration tests for --collect modes with shape validation
- tests/unit/test_cross_rank_agg.py: unit tests for cross-rank aggregation
- tests/unit/test_shape_merge.py: unit tests for shape merging
Modified files
- README.md: stage profiling documentation
- .gitignore: ignore __pycache__/
- simulator/base_parser.py: handle pytorch_to_nccl_dtype.get() returning None for unknown dtypes, which previously caused TypeError
- simulator/benchmarks/nccl_benchmarks.py: try/except ValueError around float() parsing in run_nccl_all_reduce_perf and run_nccl_all_gather_perf
- scripts/run_simulate.py: black formatting
- tests/unit/test_batch_request.py: new request fields (text_prompt_len, vision_prompt_len, timestamp), mock requests.get to prevent real HTTP calls, fix base_url
- tests/unit/test_kernel_db_coverage.py: downgrade missing-op assertion to warnings.warn — unknown kernels need manual kernels.json entries, not test failures
- tests/unit/test_defined_len.py: formatting only
- tests/unit/test_llmcompass_backend.py: formatting only
Key design decisions
Single-point API (not sweep grid)
The script profiles a single (bs, input_len, existing_ctx, decode_tokens) point per invocation. Grid sweeps are handled by the caller (shell loops, CI matrix, etc.). This simplifies the code and makes each run self-contained.
Exact token control via input_ids
Prompts use raw token IDs (input_ids) instead of text to ensure exact prefill token counts. Fixed random seeds guarantee deterministic prompts for radix-cache prefix matching between seed and profile requests.
Fencepost handling
SGLang's profiler stops when batch_count > num_steps (not >=). To capture exactly decode_tokens decode batches, we pass num_steps = max(1, decode_tokens - 1).
Cross-rank aggregation semantics
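The fencepost rule above can be checked with a toy model of the stopping condition. num_steps_for mirrors the formula from the PR description; batches_profiled is a hypothetical simulation of a profiler that stops once batch_count exceeds num_steps:

```python
def num_steps_for(decode_tokens: int) -> int:
    """num_steps to request so that exactly decode_tokens batches run."""
    return max(1, decode_tokens - 1)

def batches_profiled(num_steps: int) -> int:
    """Model a profiler that keeps running while batch_count <= num_steps:
    the batch that first makes batch_count exceed num_steps is still captured,
    so num_steps + 1 batches are profiled in total."""
    batch_count = 0
    while True:
        batch_count += 1
        if batch_count > num_steps:
            break
    return batch_count
```

Note that decode_tokens=1 still maps to num_steps=1 (two batches), which is presumably why the script enforces --decode-tokens >= 2 as a fencepost guard.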
Shape merge
CUDA graphs hide kernel input shapes. A second pass without CUDA graphs captures shapes, which are merged into timing CSVs by 1:1 kernel-name occurrence matching. Mismatches raise ValueError to catch configuration bugs early.
Testing
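The 1:1 occurrence matching above can be sketched as follows: the i-th occurrence of a kernel name in the timing pass receives the shape recorded for the i-th occurrence of the same name in the shape pass. merge_shapes and the row keys are illustrative, not the PR's actual API:

```python
from collections import defaultdict

def merge_shapes(timing_rows, shape_rows):
    """Attach shape-pass shapes to timing-pass rows by kernel-name
    occurrence order. Raises ValueError when the two passes disagree
    on how often a kernel was invoked."""
    shapes_by_kernel = defaultdict(list)
    for row in shape_rows:
        shapes_by_kernel[row["kernel"]].append(row["shape"])
    seen = defaultdict(int)
    merged = []
    for row in timing_rows:
        name = row["kernel"]
        idx = seen[name]
        if idx >= len(shapes_by_kernel[name]):
            # Occurrence counts disagree between passes: likely a config bug
            # (e.g. the two passes profiled different workloads), so fail loudly.
            raise ValueError(f"occurrence mismatch for kernel {name!r}")
        merged.append({**row, "shape": shapes_by_kernel[name][idx]})
        seen[name] += 1
    return merged
```

Failing loudly on a mismatch is the safer choice here: silently pairing shapes from a different workload would poison every downstream timing model.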
- test_cross_rank_agg.py (13 tests), test_shape_merge.py (10 tests) — run locally with pytest
- test_stage_profile_configs.py (4 tests) — require GPU + a Docker container with model configs. Filter with RUN_CONFIGS=P1
Stats