diff --git a/README.md b/README.md index f4e8533..7d80c50 100644 --- a/README.md +++ b/README.md @@ -164,6 +164,8 @@ print(f"\nThroughput: {tps:.2f} tok/s") All benchmarks share the same datasets (gsm8k, math500, humaneval, mbpp, mt-bench). Datasets are automatically downloaded and cached as JSONL in `cache/` on first run. +Benchmark over a steady request window, not the first requests after launch. The DFlash draft path needs a few requests to warm up; cold readings can undersell steady-state throughput by a large margin (see #135). The harness warms up automatically (`--num-warmup`, default 8); raise it on slower/larger setups. + **vLLM**: ```bash python -m dflash.benchmark --backend vllm \ diff --git a/dflash/benchmark.py b/dflash/benchmark.py index 7845a5a..2406901 100644 --- a/dflash/benchmark.py +++ b/dflash/benchmark.py @@ -386,7 +386,8 @@ def _run_server(args: argparse.Namespace) -> None: tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) - num_prompts = args.num_prompts + args.concurrency + warmup_n = 0 if args.num_warmup == 0 else max(args.num_warmup, args.concurrency) + num_prompts = args.num_prompts + warmup_n prompts: list[str] = [] for i in range(num_prompts): item = dataset[i % len(dataset)] @@ -429,12 +430,11 @@ def send_one(prompt: str) -> dict: except Exception: print("Warning: /flush_cache failed. Continuing.") - bs = max(args.concurrency, 1) - if len(prompts) > bs: - print(f"[warmup] {bs} requests ...") - with ThreadPoolExecutor(max_workers=bs) as pool: - list(pool.map(send_one, prompts[:bs])) - prompts = prompts[bs:] + if warmup_n > 0: + print(f"[warmup] {warmup_n} requests ...") + with ThreadPoolExecutor(max_workers=max(args.concurrency, 1)) as pool: + list(pool.map(send_one, prompts[:warmup_n])) + prompts = prompts[warmup_n:] print(f"Running benchmark: {args.num_prompts} prompts, concurrency={args.concurrency} ...") start = time.perf_counter() @@ -492,6 +492,15 @@ def main() -> None: parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000") parser.add_argument("--num-prompts", type=int, default=1024) parser.add_argument("--concurrency", type=int, default=1) + parser.add_argument( + "--num-warmup", + type=int, + default=8, + help=( + "Warmup requests sent before timing (server backends; default: 8). Draft path needs several requests to " + "reach steady state; the documented --concurrency 1 default otherwise warms with only 1 request." + ), + ) parser.add_argument("--top-p", type=float, default=1.0) parser.add_argument("--top-k", type=int, default=1) parser.add_argument("--enable-thinking", action="store_true")