Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ print(f"\nThroughput: {tps:.2f} tok/s")

All benchmarks share the same datasets (gsm8k, math500, humaneval, mbpp, mt-bench). Datasets are automatically downloaded and cached as JSONL in `cache/` on first run.

Benchmark over a steady request window, not the first requests after launch. The DFlash draft path needs a few requests to warm up; cold readings can undersell steady-state throughput by a large margin (see #135). The harness warms up automatically (`--num-warmup`, default 8); raise it on slower/larger setups.

**vLLM**:
```bash
python -m dflash.benchmark --backend vllm \
Expand Down
23 changes: 16 additions & 7 deletions dflash/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,8 @@ def _run_server(args: argparse.Namespace) -> None:

tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)

num_prompts = args.num_prompts + args.concurrency
warmup_n = 0 if args.num_warmup == 0 else max(args.num_warmup, args.concurrency)
num_prompts = args.num_prompts + warmup_n
prompts: list[str] = []
for i in range(num_prompts):
item = dataset[i % len(dataset)]
Expand Down Expand Up @@ -429,12 +430,11 @@ def send_one(prompt: str) -> dict:
except Exception:
print("Warning: /flush_cache failed. Continuing.")

bs = max(args.concurrency, 1)
if len(prompts) > bs:
print(f"[warmup] {bs} requests ...")
with ThreadPoolExecutor(max_workers=bs) as pool:
list(pool.map(send_one, prompts[:bs]))
prompts = prompts[bs:]
if warmup_n > 0:
print(f"[warmup] {warmup_n} requests ...")
with ThreadPoolExecutor(max_workers=max(args.concurrency, 1)) as pool:
list(pool.map(send_one, prompts[:warmup_n]))
prompts = prompts[warmup_n:]

print(f"Running benchmark: {args.num_prompts} prompts, concurrency={args.concurrency} ...")
start = time.perf_counter()
Expand Down Expand Up @@ -492,6 +492,15 @@ def main() -> None:
parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000")
parser.add_argument("--num-prompts", type=int, default=1024)
parser.add_argument("--concurrency", type=int, default=1)
parser.add_argument(
"--num-warmup",
type=int,
default=8,
help=(
"Warmup requests sent before timing (server backends; default: 8). Draft path needs several requests to "
"reach steady state; the documented --concurrency 1 default otherwise warms with only 1 request."
),
)
parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--top-k", type=int, default=1)
parser.add_argument("--enable-thinking", action="store_true")
Expand Down