diff --git a/oink/README.md b/oink/README.md index d21720a..bcb1032 100644 --- a/oink/README.md +++ b/oink/README.md @@ -1,65 +1,61 @@ # KernelAgent-Oink -KernelAgent-Oink is a small **CuTeDSL (CUTLASS DSL) kernel library** for -**NVIDIA Blackwell (SM10x / GB200 / GB300 / B200-class)**, bundled as a lightweight -Python package that can be used standalone or as a **vLLM general plugin**. +KernelAgent-Oink is a lightweight **CuTeDSL (CUTLASS DSL) kernel package** for +NVIDIA Blackwell **SM10x** GPUs. It can be used standalone or loaded as a +**vLLM general plugin**. -At the moment, the vLLM integration exposes the following `torch.library.custom_op` -entrypoints under the `oink::` namespace: +Current custom ops: - `torch.ops.oink.rmsnorm(x, weight, eps) -> Tensor` - `torch.ops.oink.fused_add_rms_norm(x, residual, weight, eps) -> None` (in-place) -The package also includes additional SM100 kernels used by the benchmark suite: -LayerNorm, Softmax (fwd+bwd), and CrossEntropy (fwd+bwd). +The repo also contains benchmark-facing Blackwell kernels for LayerNorm, Softmax, +and CrossEntropy. ## Requirements -- GPU: **SM10x (Blackwell)** for the fast CuTeDSL paths. On other GPUs, Oink falls back to - reference PyTorch implementations for correctness. -- Python dependencies: - - `nvidia-cutlass-dsl` (CuTeDSL) - - `cuda-python` - - `torch` (provided by your environment / vLLM) +- Blackwell GPU for optimized CuTeDSL paths; other GPUs use correctness-first + PyTorch fallbacks. 
+- `nvidia-cutlass-dsl>=4.4.2` +- `cuda-python` +- `torch` from the surrounding environment / vLLM Recommended env vars: ```bash -export CUTE_DSL_ARCH=sm_100a export PYTORCH_ALLOC_CONF=expandable_segments:True +export CUTE_DSL_ARCH=sm_103a # GB300 / SM103 +# export CUTE_DSL_ARCH=sm_100a # GB200/B200 / SM100 ``` -On **GB300 / SM103**, prefer: - -```bash -export CUTE_DSL_ARCH=sm_103a -``` - -## Install (editable) +## Install From the `KernelAgent` repo root: ```bash pip install -e ./oink +pip install -e "./oink[bench]" # optional benchmark/plot deps ``` -For running the in-repo benchmark suite / plots: +A reproducible GB300 benchmark environment used for the results below: ```bash -pip install -e "./oink[bench]" +conda create -y -n cute python=3.12 +conda run -n cute python -m pip install --upgrade pip setuptools wheel packaging ninja +conda run -n cute python -m pip install --upgrade --index-url https://download.pytorch.org/whl/cu130 torch +conda run -n cute python -m pip install 'nvidia-cutlass-dsl==4.4.2' cuda-python triton matplotlib +conda run -n cute python -m pip install -e './oink[bench]' ``` ## Usage -### vLLM (general plugin) - -1) Enable the plugin: +### vLLM plugin ```bash export VLLM_USE_OINK_RMSNORM=1 ``` -2) Ensure vLLM keeps `rms_norm` as a custom op when using `torch.compile` / CUDA graphs: +When using `torch.compile` / CUDA graphs, keep vLLM RMSNorm as a custom op: ```python from vllm import LLM @@ -72,12 +68,7 @@ llm = LLM( ) ``` -Without `+rms_norm`, Inductor may fuse RMSNorm into larger kernels and neither -vLLM’s CUDA RMSNorm nor Oink will run. 
- -### Direct PyTorch usage (manual op registration) - -For standalone use (outside vLLM), register the custom ops once: +### Direct PyTorch ```python import kernelagent_oink @@ -92,73 +83,40 @@ y = torch.ops.oink.rmsnorm(x, w, 1e-6) ## Benchmarks -### GB200 / B200 (SM100) benchmark suite - -The repo includes a Quack-style benchmark suite (tables + SVG plots) to compare -Oink against Quack and to reproduce the reported speedups. The pre-generated -plots below were measured on **GB200 / B200-class SM100** systems. - -In short, Oink’s edge comes from lower pointer-path launch overhead plus Blackwell-tuned shape routing for both hot small-`M` and larger RMSNorm rows. - -On the current B200 forward sweep, Oink holds `1.12x` / `1.06x` geomean over Quack for same-dtype weights on the Quack-suite / DSv3 sets, and `1.18x` / `1.06x` for fp32 weights, with worst output rel-L2 `1.45e-5` (Quack `2.01e-5`). - -- How to run + methodology: `oink/benchmarks/README.md` -- Pre-generated plots: `oink/benchmarks/media/` - -
- SM100 BF16: Oink vs Quack (Quack-suite) -
- -
- SM100 BF16: Oink vs Quack (DSv3-like shapes) -
- -### GB300 (SM103) Q/K-norm results +Benchmark details and commands are in [`benchmarks/README.md`](benchmarks/README.md). +Reported numbers are correctness-gated against PyTorch references before timing. -We also benchmarked the real Llama4x-style Q/K-norm workload on **GB300 -(SM103)** using non-contiguous `q` / `k` views produced by `qkv.split()`. This -benchmark reports both the direct CuTeDSL/CUTLASS baseline and the optimized -Oink path for the production strided `[M, N]` views. The CuTeDSL/CUTLASS -baseline here is a **Q/K-norm adaptation** derived from the -[CUTLASS CuTeDSL Blackwell RMSNorm example](https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/rmsnorm.py), -not the example kernel used unchanged. +Current GB300 / SM103 setup: -For roofline context, we also plot the same workload using a dedicated -useful-bandwidth harness: median CUDA-event timing plus a logical IO model of -one read + one write of the fused `[M, N]` tensor. This is the physically -meaningful view for comparing against the measured practical GB300 BF16 stream -roof, whereas the steady-state CUDA-graph replay medians below are better read -as a latency view. +- NVIDIA GB300, capability `(10, 3)`, `CUTE_DSL_ARCH=sm_103a` +- `torch==2.11.0+cu130`, CUDA `13.0` +- `nvidia-cutlass-dsl==4.4.2`, `cuda-python==13.2.0` +- measured BF16 STREAM-like roof: **7.140 TB/s**
- GB300 BF16: Q/K-norm roofline (Oink vs CuTeDSL) + SM103 / GB300 BF16 benchmark summary
-Representative steady-state CUDA-graph replay medians from one GB300 run are -shown below (absolute microseconds may vary slightly run to run, but the -ranking and trend were stable). +Quack-suite BF16 summary (`N=4096`): -- Q path: Oink is roughly **2.4–3.1x faster** than the CuTeDSL baseline on - representative multi-row workloads. -- K path: Oink is roughly **2.0–3.6x faster** on the same sweep. +| op | rows | geomean vs Quack | large-row roofline note | +|---|---:|---:|---| +| RMSNorm fwd, weight=same | 19 | 1.019x | near measured roof on large rows | +| RMSNorm fwd, weight=fp32 | 19 | 1.100x | near measured roof on large rows | +| LayerNorm fwd | 19 | 1.241x | near measured roof on large rows | +| Softmax fwd+bwd | 19 | 1.673x | near measured roof on large rows | +| CrossEntropy fwd+bwd | 19 | 1.635x | mixed memory/SFU behavior | -Takeaways from the GB300 Q/K-norm sweep: +Historical plots remain under `benchmarks/media/`: -- For the user-relevant multi-row workloads, Oink beats the CuTeDSL/CUTLASS - baseline by comfortably more than 20%. -- In the roofline view, Oink gets close to the practical GB300 BF16 streaming - ceiling on the large-row Q/K shapes, while the CuTeDSL baseline stays much - farther from the roof. -- The only cases below 20% are the tiny single-row latency-floor microcases: - Q `M=1` is ~12% faster and K `M=1` is ~6% faster. -- Correctness spot-check from the same harness: - - Q max diff vs eager: `0.03125` - - K max diff vs eager: `0.007812` +- `sm100_*`: historical SM100 / B200 runs. +- `gb300_bf16_qk_norm_oink_vs_cutedsl_roofline.svg`: historical GB300 Q/K-norm + harness, separate from the Quack-suite table above. 
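The per-op "geomean vs Quack" column above is the standard geometric mean of per-shape speedups. As a minimal sketch (the helper name is illustrative, not the repo's `summarize_results.py` implementation):

```python
import math

def geomean_speedup(oink_ms, baseline_ms):
    """Geometric mean of per-shape speedups (baseline_ms / oink_ms)."""
    assert len(oink_ms) == len(baseline_ms) and oink_ms
    ratios = [b / o for o, b in zip(oink_ms, baseline_ms)]
    return math.exp(sum(math.log(r) for r in ratios) / len(ratios))

# Three shapes with per-shape speedups of ~2.02x, ~2.04x, ~2.09x
print(round(geomean_speedup([0.0360, 0.1206, 0.0396],
                            [0.0727, 0.2463, 0.0828]), 3))  # → 2.051
```

Using the log-sum form (rather than multiplying ratios directly) keeps the computation numerically stable for long shape sweeps.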
## Links | What | Link | |---|---| -| Quack (expert baseline) | https://github.com/Dao-AILab/quack | -| KernelAgent (agentic framework) | https://github.com/meta-pytorch/KernelAgent | -| vLLM PR (Oink RMSNorm integration) | https://github.com/vllm-project/vllm/pull/31828 | +| Quack baseline | https://github.com/Dao-AILab/quack | +| KernelAgent | https://github.com/meta-pytorch/KernelAgent | +| vLLM Oink RMSNorm PR | https://github.com/vllm-project/vllm/pull/31828 | diff --git a/oink/benchmarks/README.md b/oink/benchmarks/README.md index 26f3c07..9685000 100644 --- a/oink/benchmarks/README.md +++ b/oink/benchmarks/README.md @@ -1,30 +1,39 @@ -# SM100 Benchmarks (KernelAgent-Oink vs Quack) +# Blackwell SM10x Benchmarks (KernelAgent-Oink vs Quack) -This folder contains SM10x (GB200 / GB300 / Blackwell) microbenchmarks for the Oink -CuTeDSL kernels vendored into KernelAgent, comparing against Quack’s SM100 -kernels where Quack provides an equivalent API. +This folder contains SM10x (GB200 / GB300 / Blackwell) microbenchmarks for the +Oink CuTeDSL kernels, comparing against Quack’s SM100 kernels where Quack +provides an equivalent API. ## Prereqs - GPU: **SM10x / Blackwell** (`torch.cuda.get_device_capability()[0] == 10`). 
- Python deps in your environment: - `torch` - - `nvidia-cutlass-dsl` (CuTeDSL) + - `nvidia-cutlass-dsl>=4.4.2` (CuTeDSL) - `cuda-python` - `triton` (only for `triton.testing.do_bench`) - - `quack` (optional; only needed for Oink-vs-Quack comparisons) + - `quack` / `quack-kernels` (optional; only needed for Oink-vs-Quack comparisons) Recommended env vars: ```bash export PYTORCH_ALLOC_CONF=expandable_segments:True -export CUTE_DSL_ARCH=sm_100a +# GB300 / SM103: +export CUTE_DSL_ARCH=sm_103a +# GB200/B200 / SM100 historical runs: +# export CUTE_DSL_ARCH=sm_100a ``` -On **GB300 / SM103**, prefer: +For the pinned GB300 / SM103 benchmark environment used by the current README +numbers: ```bash -export CUTE_DSL_ARCH=sm_103a +conda create -y -n cute python=3.12 +conda run -n cute python -m pip install --upgrade pip setuptools wheel packaging ninja +conda run -n cute python -m pip install --upgrade --index-url https://download.pytorch.org/whl/cu130 torch +conda run -n cute python -m pip install 'nvidia-cutlass-dsl==4.4.2' cuda-python triton matplotlib pytest pytest-cov +conda run -n cute python -m pip install -e '.[bench]' +conda run -n cute python -m pip install 'git+https://github.com/Dao-AILab/quack.git' # optional comparison baseline ``` ## Shape suites @@ -34,21 +43,47 @@ export CUTE_DSL_ARCH=sm_103a - **DeepSeek-V3-like (DSv3)** - RMSNorm / LayerNorm / Softmax: `M ∈ {4096, 16384, 65536}`, `N ∈ {6144, 7168, 8192}` - Cross-entropy: `M ∈ {4096, 16384, 65536}`, `N ∈ {3072, 6144, 8192, 12288}` +- **DeepSeek-V4-Flash norm shapes (DSv4)** from `deepseek-ai/DeepSeek-V4-Flash/inference/model.py` + - hidden-state RMSNorm / LayerNorm: `M ∈ {4096, 16384, 65536}`, `N = 7168` + - q_lora RMSNorm: `M ∈ {4096, 16384, 65536}`, `N = 1536` + - kv latent / per-head RMSNorm: `M ∈ {4096, 16384, 65536}`, `N = 512` ## Correctness gates -By default, each script runs a per-shape `torch.testing.assert_close` check -vs a **pure-PyTorch reference** **before** emitting timing numbers. 
When Quack -is available for that op/path, the script also validates Quack vs the *same* +By default, each script runs a per-shape `torch.testing.assert_close` check vs a +**pure-PyTorch reference** **before** emitting timing numbers. When Quack is +available for that op/path, the script also validates Quack vs the *same* reference (so speedups can’t come from looser numerics). -Disable with `--skip-verify` only for quick smoke tests. +Disable with `--skip-verify` only for quick smoke tests. Do not use +`--skip-verify` for README or release performance numbers. + +## Roofline reporting + +Most benchmark JSONs include `*_hbm_frac` using `bench_utils.detect_hbm_peak_gbps()`. +That helper is a coarse fallback (`8000 GB/s` for SM10x) so old JSONs can be +compared consistently. For GB300/SM103 published results, use a measured roofline +run instead. + +Current measured GB300 BF16 STREAM-like roof used in the README: + +- **7.140 TB/s** (triad, `BLOCK=2048`, `warps=8`) +- 90% target: **6.426 TB/s** + +Regenerate on the current machine: + +```bash +conda run -n cute bash -lc 'PYTHONNOUSERSITE=1 CUTE_DSL_ARCH=sm_103a \ + python benchmarks/benchmark/benchmark_hbm_roofline_sm100.py --dtype bf16 --op both --gb 1 \ + --json /tmp/oink_sm103_hbm_roofline_bf16_current.json' +``` ## Running benchmarks -All scripts support: +All primary scripts support: -- `--quack-suite` or `--dsv3` (or `--configs MxN,...`) +- `--quack-suite` or `--dsv3` (and `--dsv4` where applicable) +- `--configs MxN,...` - `--dtype {bf16,fp16,fp32}` - `--iters ` and `--warmup-ms ` for kernel-only timing - `--json ` and/or `--csv ` outputs (meta + rows) @@ -59,100 +94,164 @@ Run the full Quack-suite + DSv3 set (Oink vs Quack) and write all JSON artifacts to a timestamped directory: ```bash -python oink/benchmarks/readme/run_sm100_suite.py --dtype bf16 +conda run -n cute bash -lc 'PYTHONNOUSERSITE=1 CUTE_DSL_ARCH=sm_103a \ + python benchmarks/readme/run_sm100_suite.py --dtype bf16' + +# Include DeepSeek-V4-Flash 
norm workloads: +conda run -n cute bash -lc 'PYTHONNOUSERSITE=1 CUTE_DSL_ARCH=sm_103a \ + python benchmarks/readme/run_sm100_suite.py --dtype bf16 --include-dsv4 \ + --out-dir /tmp/oink_sm103_suite_bf16_current' ``` -Turn the JSON artifacts into Markdown tables (with geomean speedups): +Turn JSON artifacts into Markdown tables (with geomean speedups): ```bash -python oink/benchmarks/readme/summarize_results.py --in-dir /tmp/kernelagent_oink_sm100_suite_ \ - --out /tmp/kernelagent_oink_sm100_suite_summary.md +conda run -n cute bash -lc 'python benchmarks/readme/summarize_results.py \ + --in-dir /tmp/oink_sm103_suite_bf16_current \ + --out /tmp/oink_sm103_suite_bf16_current_summary.md' ``` -### Measured HBM roofline (STREAM-like) - -To contextualize the `*_tbps` numbers as a fraction of a *measured* bandwidth -ceiling (rather than a theoretical spec), run: +Generate SM103 SVGs from current JSONs and measured roofline: ```bash -CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py --dtype bf16 --op both --gb 2 \ - --json /tmp/hbm_roofline_sm100_bf16.json +conda run -n cute bash -lc 'python benchmarks/readme/plot_quack_style_svg.py \ + --in-dir /tmp/oink_sm103_suite_bf16_current \ + --suite quack_suite --include-layernorm \ + --roofline-json /tmp/oink_sm103_hbm_roofline_bf16_current.json \ + --arch-label "SM103 / GB300" \ + --out benchmarks/media/sm103_bf16_oink_vs_quack_with_layernorm.svg' + +conda run -n cute bash -lc 'python benchmarks/readme/plot_quack_style_svg.py \ + --in-dir /tmp/oink_sm103_suite_bf16_current \ + --suite dsv3_all --shape-policy first \ + --roofline-json /tmp/oink_sm103_hbm_roofline_bf16_current.json \ + --arch-label "SM103 / GB300" \ + --out benchmarks/media/sm103_bf16_oink_vs_quack_dsv3_all.svg' ``` +The existing `sm100_*` SVGs in `benchmarks/media/` are historical SM100/B200 +plots. Do not use them as GB300 evidence. 
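Since `detect_hbm_peak_gbps()` is only a coarse fallback, fractions for published GB300 numbers should be recomputed against the measured roof. A minimal sketch, assuming a JSON row carries an `oink_tbps` field (the exact key name is an assumption; check the emitted JSON):

```python
MEASURED_ROOF_TBPS = 7.140  # GB300 BF16 STREAM-like roof quoted above

def recompute_hbm_frac(row, roof_tbps=MEASURED_ROOF_TBPS):
    """Re-express a benchmark row's bandwidth as a fraction of the
    measured roof instead of the coarse 8000 GB/s fallback."""
    return row["oink_tbps"] / roof_tbps

# Hypothetical row mirroring the M=65536, N=8192 fused-add result
row = {"M": 65536, "N": 8192, "oink_tbps": 7.068}
print(f"{recompute_hbm_frac(row):.1%}")  # → 99.0%
```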
+ ### RMSNorm forward ```bash -python oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype fp32 --quack-suite --iters 200 --warmup-ms 25 \ +python benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype fp32 --quack-suite --iters 200 --warmup-ms 25 \ --json /tmp/oink_rmsnorm_fwd_quack_suite.json -python oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype fp32 --dsv3 --iters 200 --warmup-ms 25 \ +python benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype fp32 --dsv3 --iters 200 --warmup-ms 25 \ --json /tmp/oink_rmsnorm_fwd_dsv3.json # vLLM-style inference weights (weight dtype == activation dtype) -python oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype same --quack-suite --iters 200 --warmup-ms 25 \ +python benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype same --quack-suite --iters 200 --warmup-ms 25 \ --json /tmp/oink_rmsnorm_fwd_quack_suite_wsame.json + +# DeepSeek-V4-Flash norm grid +python benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype same --dsv4 --iters 200 --warmup-ms 25 \ + --json /tmp/oink_rmsnorm_fwd_dsv4_wsame.json ``` ### Fused Add + RMSNorm (vLLM-style, in-place) -This is a good "roofline case study" kernel (heavy read/write traffic, very little extra math): +This is a good roofline case study kernel (heavy read/write traffic, very little +extra math). Oink exposes an **in-place** fused op that updates `x` and +`residual`. Quack's fused kernel writes separate `out` and `residual_out` +buffers, so the default benchmark baseline (`--quack-baseline kernel_inplace`) +times Quack plus the copies needed to match Oink's in-place semantics. Use +`--quack-baseline kernel` to time only the Quack kernel with preallocated +outputs. 
```bash -CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py --dtype bf16 --M 65536 --N 4096 \ - --json /tmp/fused_add_rmsnorm_sm100_bf16.json +# DeepSeek-V3 hidden-size sweep +PYTHONNOUSERSITE=1 CUTE_DSL_ARCH=sm_103a \ + python benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py \ + --dtype bf16 --dsv3 --iters 80 --warmup-ms 15 \ + --quack-baseline kernel_inplace \ + --json /tmp/oink_sm103_fused_add_rmsnorm_dsv3_bf16.json + +# DeepSeek-V4-Flash hidden-state sweep (N=7168) +PYTHONNOUSERSITE=1 CUTE_DSL_ARCH=sm_103a \ + python benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py \ + --dtype bf16 --dsv4 --iters 80 --warmup-ms 15 \ + --quack-baseline kernel_inplace \ + --json /tmp/oink_sm103_fused_add_rmsnorm_dsv4_bf16.json ``` -Note on the Quack baseline: Oink exposes an **in-place** fused op (updates `x` and `residual`). -Quack’s fused kernel produces `out` and `residual_out` out-of-place, so by default the benchmark -times `quack::_rmsnorm_fwd` **plus** two explicit copies (`x.copy_(out)`, `residual.copy_(residual_out)`) -to match the in-place semantics (integration-realistic). Use `--quack-baseline kernel` to time only -the Quack fused kernel with preallocated outputs. 
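The in-place semantics being benchmarked can be written as a short pure-PyTorch reference: `residual` receives `x + residual`, and `x` receives the RMS-normalized sum. This is a sketch of the semantics only (mirroring the benchmark's correctness reference), not the CuTeDSL kernel:

```python
import torch

def fused_add_rms_norm_ref(x, residual, weight, eps=1e-6):
    """In-place reference: residual <- x + residual,
    x <- rmsnorm(x + residual) * weight, reducing in fp32."""
    hidden = x.float() + residual.float()
    residual.copy_(hidden.to(residual.dtype))
    var = hidden.pow(2).mean(dim=-1, keepdim=True)
    y = hidden * torch.rsqrt(var + eps) * weight.float()
    x.copy_(y.to(x.dtype))
```

Note both tensors are mutated; this is why the default Quack baseline adds explicit copies to match.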
+Current GB300 / SM103 BF16 results from correctness-gated runs: + +| suite | rows | speedup vs Quack (min / geomean / max) | +|---|---:|---:| +| DSv3 fused-add RMSNorm | 9 | 2.022x / 2.045x / 2.089x | +| DSv4 fused-add RMSNorm | 3 | 2.030x / 2.192x / 2.521x | + +DSv3 per-shape results: + +| M | N | Oink ms | Quack ms | speedup | Oink TB/s | +|---:|---:|---:|---:|---:|---:| +| 4096 | 6144 | 0.0360 | 0.0727 | 2.022x | 5.598 | +| 4096 | 7168 | 0.0396 | 0.0828 | 2.089x | 5.926 | +| 4096 | 8192 | 0.0479 | 0.0993 | 2.076x | 5.610 | +| 16384 | 6144 | 0.1206 | 0.2463 | 2.043x | 6.678 | +| 16384 | 7168 | 0.1393 | 0.2830 | 2.031x | 6.742 | +| 16384 | 8192 | 0.1574 | 0.3212 | 2.040x | 6.821 | +| 65536 | 6144 | 0.4575 | 0.9285 | 2.030x | 7.041 | +| 65536 | 7168 | 0.5329 | 1.0785 | 2.024x | 7.052 | +| 65536 | 8192 | 0.6077 | 1.2466 | 2.052x | 7.068 | + +DSv4 per-shape results: + +| M | N | Oink ms | Quack ms | speedup | Oink TB/s | +|---:|---:|---:|---:|---:|---:| +| 4096 | 7168 | 0.0415 | 0.1047 | 2.521x | 5.655 | +| 16384 | 7168 | 0.1388 | 0.2855 | 2.057x | 6.769 | +| 65536 | 7168 | 0.5314 | 1.0785 | 2.030x | 7.072 | ### RMSNorm backward ```bash -python oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py --dtype bf16 --weight-dtype fp32 --quack-suite --iters 100 --warmup-ms 25 \ +python benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py --dtype bf16 --weight-dtype fp32 --quack-suite --iters 100 --warmup-ms 25 \ --csv /tmp/oink_rmsnorm_bwd_quack_suite.csv -python oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py --dtype bf16 --weight-dtype fp32 --dsv3 --iters 100 --warmup-ms 25 \ +python benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py --dtype bf16 --weight-dtype fp32 --dsv3 --iters 100 --warmup-ms 25 \ --csv /tmp/oink_rmsnorm_bwd_dsv3.csv ``` ### Softmax (forward + backward) ```bash -python oink/benchmarks/benchmark/benchmark_softmax_sm100.py --dtype bf16 --mode fwd_bwd --quack-suite --iters 50 --warmup-ms 25 \ +python 
benchmarks/benchmark/benchmark_softmax_sm100.py --dtype bf16 --mode fwd_bwd --quack-suite --iters 50 --warmup-ms 25 \ --json /tmp/oink_softmax_fwd_bwd_quack_suite.json -python oink/benchmarks/benchmark/benchmark_softmax_sm100.py --dtype bf16 --mode fwd_bwd --dsv3 --iters 50 --warmup-ms 25 \ +python benchmarks/benchmark/benchmark_softmax_sm100.py --dtype bf16 --mode fwd_bwd --dsv3 --iters 50 --warmup-ms 25 \ --json /tmp/oink_softmax_fwd_bwd_dsv3.json ``` ### Cross-entropy (forward + backward) ```bash -python oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py --dtype bf16 --mode fwd_bwd --quack-suite --iters 50 --warmup-ms 25 \ +python benchmarks/benchmark/benchmark_cross_entropy_sm100.py --dtype bf16 --mode fwd_bwd --quack-suite --iters 50 --warmup-ms 25 \ --json /tmp/oink_cross_entropy_fwd_bwd_quack_suite.json -python oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py --dtype bf16 --mode fwd_bwd --dsv3 --iters 50 --warmup-ms 25 \ +python benchmarks/benchmark/benchmark_cross_entropy_sm100.py --dtype bf16 --mode fwd_bwd --dsv3 --iters 50 --warmup-ms 25 \ --json /tmp/oink_cross_entropy_fwd_bwd_dsv3.json ``` ### LayerNorm forward ```bash -python oink/benchmarks/benchmark/benchmark_layernorm_sm100.py --dtype bf16 --quack-suite --iters 200 --warmup-ms 25 \ +python benchmarks/benchmark/benchmark_layernorm_sm100.py --dtype bf16 --quack-suite --iters 200 --warmup-ms 25 \ --json /tmp/oink_layernorm_fwd_quack_suite.json -python oink/benchmarks/benchmark/benchmark_layernorm_sm100.py --dtype bf16 --dsv3 --iters 200 --warmup-ms 25 \ +python benchmarks/benchmark/benchmark_layernorm_sm100.py --dtype bf16 --dsv3 --iters 200 --warmup-ms 25 \ --json /tmp/oink_layernorm_fwd_dsv3.json ``` ## Notes - These scripts intentionally avoid importing any external Oink checkout so the - results reflect the in-tree KernelAgent Oink kernels. 
-- For RMSNorm, the `rmsnorm_with_stage2` implementation is a **fallback** that - is only used when the pointer-based fast path cannot be used (e.g. when - `weight.dtype != x.dtype`, or when layouts/alignments are incompatible). You + results reflect the in-tree KernelAgent-Oink kernels. +- `src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py` is a compatibility + facade. The stage-2 scheduling policy lives in `_rmsnorm_impl.py`; keep the + facade for downstream imports. +- For RMSNorm, the stage-2 path is a fallback used when the pointer-based fast + path cannot be used (for example when layouts/alignments are incompatible). You can force it for A/B testing via `KERNELAGENT_OINK_FORCE_RMSNORM_STAGE2=1`. diff --git a/oink/benchmarks/benchmark/bench_utils.py b/oink/benchmarks/benchmark/bench_utils.py index dd1c2d6..6ba7a3f 100644 --- a/oink/benchmarks/benchmark/bench_utils.py +++ b/oink/benchmarks/benchmark/bench_utils.py @@ -93,7 +93,12 @@ def ensure_blackwell_arch_env(device: Optional[torch.device] = None) -> str: def detect_hbm_peak_gbps(device: Optional[torch.device] = None) -> float: - """Approximate HBM peak bandwidth in GB/s for roofline fractions.""" + """Return a coarse fallback HBM peak in GB/s for benchmark JSON fields. + + This helper is intentionally approximate. For published GB300/SM103 + roofline reporting, prefer a measured roofline JSON from + ``benchmark_hbm_roofline_sm100.py`` and compute fractions against that run. + """ if device is None: device = torch.device("cuda") props = torch.cuda.get_device_properties(device) @@ -144,6 +149,25 @@ def quack_suite_configs() -> List[Tuple[int, int, int]]: return cfgs +def dsv4_norm_configs() -> List[Tuple[int, int]]: + """Return DeepSeek-V4-Flash norm shapes from `inference/model.py`. 
+ + Source dimensions: + - hidden-state norm: N=7168 + - q_lora norm: N=1536 + - kv latent / per-head norm: N=512 + """ + Ms = [4096, 16384, 65536] + Ns = [7168, 1536, 512] + return [(m, n) for n in Ns for m in Ms] + + +def dsv4_hidden_norm_configs() -> List[Tuple[int, int]]: + """Return DeepSeek-V4-Flash hidden-state norm shapes (N=7168).""" + Ms = [4096, 16384, 65536] + return [(m, 7168) for m in Ms] + + def ensure_oink_src_on_path() -> None: """Make the in-repo KernelAgent Oink package importable without an editable install.""" here = os.path.dirname(os.path.abspath(__file__)) diff --git a/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py index e9e5b22..148c8e5 100644 --- a/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py +++ b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py @@ -58,6 +58,7 @@ collect_device_meta, detect_hbm_peak_gbps, do_bench_triton, + dsv4_hidden_norm_configs, ensure_blackwell_arch_env, error_stats_to_row, ensure_oink_src_on_path, @@ -310,6 +311,11 @@ def main() -> None: action="store_true", help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}", ) + p.add_argument( + "--dsv4", + action="store_true", + help="Run DSv4 hidden-state fused-add RMSNorm set: M in {4096,16384,65536}, N=7168", + ) p.add_argument("--warmup-ms", type=int, default=25) p.add_argument( "--iters", type=int, default=200, help="rep_ms for do_bench (default: 200)" @@ -333,7 +339,12 @@ def main() -> None: dtype = parse_dtype(args.dtype) meta = collect_device_meta(torch.device("cuda")) - cfgs = dsv3_configs() if bool(args.dsv3) else [(int(args.M), int(args.N))] + if bool(args.dsv3): + cfgs = dsv3_configs() + elif bool(args.dsv4): + cfgs = dsv4_hidden_norm_configs() + else: + cfgs = [(int(args.M), int(args.N))] rows: List[Dict[str, Any]] = [] for M, N in cfgs: print( diff --git a/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py 
b/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py index 7ad7f77..569d4c7 100644 --- a/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py +++ b/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py @@ -28,6 +28,7 @@ collect_device_meta, detect_hbm_peak_gbps, do_bench_triton, + dsv4_hidden_norm_configs, ensure_blackwell_arch_env, error_stats_to_row, ensure_oink_src_on_path, @@ -355,6 +356,11 @@ def main() -> None: action="store_true", help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}", ) + p.add_argument( + "--dsv4", + action="store_true", + help="Run DSv4 hidden-state LayerNorm set: M in {4096,16384,65536}, N=7168", + ) p.add_argument( "--skip-verify", action="store_true", @@ -369,6 +375,8 @@ def main() -> None: cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()] elif args.dsv3: cfgs = dsv3_configs() + elif args.dsv4: + cfgs = dsv4_hidden_norm_configs() else: cfgs = parse_configs(args.configs) diff --git a/oink/benchmarks/benchmark/benchmark_paulius_rmsnorm.py b/oink/benchmarks/benchmark/benchmark_paulius_rmsnorm.py deleted file mode 100644 index b2bf75c..0000000 --- a/oink/benchmarks/benchmark/benchmark_paulius_rmsnorm.py +++ /dev/null @@ -1,242 +0,0 @@ -from __future__ import annotations - -import argparse -import os -import re -import subprocess -from pathlib import Path -from typing import Any, Dict, List, Tuple - -import torch - -from bench_utils import collect_device_meta, detect_hbm_peak_gbps, write_csv, write_json - - -def _bench_oink_smallm_noweight(M: int, N: int) -> float: - import sys - - from triton.testing import do_bench_cudagraph - - repo_src = Path(__file__).resolve().parents[2] / "src" - if str(repo_src) not in sys.path: - sys.path.insert(0, str(repo_src)) - from kernelagent_oink.blackwell import _rmsnorm_impl as impl - - x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16) - out = torch.empty_like(x) - return float( - do_bench_cudagraph( - lambda: impl._rmsnorm_forward_ptr_into( - x, None, 
None, None, out, None, None, 1e-6 - ), - rep=100, - return_mode="mean", - ) - ) - - -def bytes_io_model_fwd(M: int, N: int, dtype: torch.dtype) -> int: - elem = torch.tensor(0, dtype=dtype).element_size() - return int(2 * M * N * elem) - - -def _cuda_13_nvcc() -> Path: - nvcc = Path("/usr/local/cuda-13.0/bin/nvcc") - if not nvcc.is_file(): - raise FileNotFoundError(f"CUDA 13.0 nvcc not found at {nvcc}") - return nvcc - - -def _build_paulius_binary(src_dir: Path) -> Path: - nvcc = _cuda_13_nvcc() - out = src_dir / "r.out" - cmd = [ - str(nvcc), - "-arch=sm_100", - "-Xptxas", - "-v", - "-O3", - "RmsNorm.cu", - "-I../../../", - "-o", - str(out), - "-lnvidia-ml", - ] - env = os.environ.copy() - env["CUDA_HOME"] = "/usr/local/cuda-13.0" - env["PATH"] = f"/usr/local/cuda-13.0/bin:{env.get('PATH', '')}" - subprocess.run( - cmd, - cwd=src_dir, - env=env, - check=True, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - ) - return out - - -def _parse_paulius_output(text: str) -> List[Tuple[float, float]]: - rows: List[Tuple[float, float]] = [] - pattern = re.compile(r"BF16\s+\d+:\s+([0-9.eE+-]+)\s+ms\s+([0-9.eE+-]+)\s+GB/s") - for line in text.splitlines(): - match = pattern.search(line) - if match is None: - continue - rows.append((float(match.group(1)), float(match.group(2)))) - return rows - - -def _run_paulius( - binary: Path, - *, - M: int, - N: int, - cta_dim_y: int, - warmup_reps: int, - timing_reps: int, - gpu_id: int, -) -> Tuple[float, float, Dict[str, Any]]: - cmd = [ - str(binary), - str(M), - str(N), - str(cta_dim_y), - str(warmup_reps), - str(timing_reps), - str(gpu_id), - "0", - "5", - "1", - ] - proc = subprocess.run( - cmd, - cwd=binary.parent, - text=True, - capture_output=True, - check=True, - ) - parsed = _parse_paulius_output(proc.stdout) - if not parsed: - raise RuntimeError( - f"Failed to parse Paulius output:\n{proc.stdout}\n{proc.stderr}" - ) - ms, gbps = min(parsed, key=lambda row: row[0]) - return ms, gbps, {"raw_stdout": proc.stdout, 
"raw_stderr": proc.stderr} - - -def main() -> None: - if not torch.cuda.is_available(): - raise SystemExit("CUDA not available") - - torch.cuda.set_device(0) - device = torch.device("cuda") - props = torch.cuda.get_device_properties(device) - sm = props.major * 10 + props.minor - print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})") - - p = argparse.ArgumentParser() - p.add_argument( - "--paulius-dir", - type=str, - default=os.path.expanduser("~/fbsource/fbcode/scripts/paulius/rmsnorm"), - ) - p.add_argument("--gpu-id", type=int, default=0) - p.add_argument("--warmup-reps", type=int, default=10) - p.add_argument("--timing-reps", type=int, default=100) - p.add_argument("--configs", type=str, default="4096x4096,65536x4096") - p.add_argument("--csv", type=str, default=None) - p.add_argument("--json", type=str, default=None) - args = p.parse_args() - - src_dir = Path(args.paulius_dir) - binary = _build_paulius_binary(src_dir) - - cfgs: List[Tuple[int, int]] = [] - for part in args.configs.split(","): - m, n = part.lower().split("x") - cfgs.append((int(m), int(n))) - - meta = collect_device_meta(device) - hbm_peak = detect_hbm_peak_gbps(device) - rows_out: List[Dict[str, Any]] = [] - for M, N in cfgs: - if N != 4096: - raise SystemExit("Paulius benchmark only supports N=4096") - best_ms = float("inf") - best_gbps = 0.0 - best_cta_dim_y = -1 - debug_runs: List[Dict[str, Any]] = [] - for cta_dim_y in (1, 2, 4, 8): - ms, gbps, debug = _run_paulius( - binary, - M=M, - N=N, - cta_dim_y=cta_dim_y, - warmup_reps=int(args.warmup_reps), - timing_reps=int(args.timing_reps), - gpu_id=int(args.gpu_id), - ) - debug_runs.append({"cta_dim_y": cta_dim_y, "ms": ms, "gbps": gbps}) - if ms < best_ms: - best_ms = ms - best_gbps = gbps - best_cta_dim_y = cta_dim_y - row: Dict[str, Any] = { - "M": M, - "N": N, - "dtype": "bf16", - "paulius_ms": best_ms, - "paulius_gbps": best_gbps, - "paulius_tbps": best_gbps / 1000.0, - "paulius_hbm_frac": best_gbps / hbm_peak, - 
"best_cta_dim_y": best_cta_dim_y, - "io_model_bytes": bytes_io_model_fwd(M, N, torch.bfloat16), - "cta_dim_y_candidates": debug_runs, - } - if M == 4096: - oink_ms = _bench_oink_smallm_noweight(M, N) - oink_gbps = ( - bytes_io_model_fwd(M, N, torch.bfloat16) / (oink_ms * 1e-3) / 1e9 - ) - row.update( - { - "oink_kernel_ms": oink_ms, - "oink_kernel_tbps": oink_gbps / 1000.0, - "oink_speedup_vs_paulius": best_ms / oink_ms, - } - ) - rows_out.append(row) - - if args.csv is not None: - write_csv(args.csv, rows_out) - if args.json is not None: - write_json( - args.json, - meta, - rows_out, - extra={ - "method": "Paulius CUDA benchmark binary", - "warmup_reps": int(args.warmup_reps), - "timing_reps": int(args.timing_reps), - "paulius_dir": str(src_dir), - }, - ) - - print("\nSummary:") - print( - f"{'M':>14} {'N':>14} {'paulius_ms':>14} {'paulius_tbps':>14}" - f" {'ctaDimY':>14} {'oink_ms':>14} {'oink/paulius':>14}" - ) - for r in rows_out: - oink_ms = float(r.get("oink_kernel_ms", float("nan"))) - speedup = float(r.get("oink_speedup_vs_paulius", float("nan"))) - print( - f"{int(r['M']):>14} {int(r['N']):>14} {float(r['paulius_ms']):14.4f}" - f" {float(r['paulius_tbps']):14.4f} {int(r['best_cta_dim_y']):>14}" - f" {oink_ms:14.4f} {speedup:14.4f}" - ) - - -if __name__ == "__main__": - main() diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_all.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_all.py deleted file mode 100644 index 6d6eb3d..0000000 --- a/oink/benchmarks/benchmark/benchmark_rmsnorm_all.py +++ /dev/null @@ -1,325 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Benchmark aten vs quack vs oink RMSNorm: normal dispatch + CUDA graph. - -All calls go through ``torch.ops.aten._fused_rms_norm``. -Quack is registered via ``torch._native`` (quack PR pattern). -Oink is registered via ``kernelagent_oink.register_all_kernels()``. - -Produces four tables: - - Forward (normal dispatch) - - Forward + Backward (normal dispatch) - - Forward (CUDA graph) - - Forward + Backward (CUDA graph) - -Usage:: - - python oink/benchmarks/benchmark/benchmark_rmsnorm_all.py -""" - -from __future__ import annotations - -import json -import os -import subprocess -import sys -import tempfile - -os.environ.setdefault("TORCH_NATIVE_SKIP_VERSION_CHECK", "1") - - -# --------------------------------------------------------------------------- -# Worker code: runs in a subprocess per mode to avoid cross-contamination. 
-# --------------------------------------------------------------------------- - -WORKER_CODE = r""" -import json, os, sys -os.environ.setdefault("TORCH_NATIVE_SKIP_VERSION_CHECK", "1") - -import torch -from triton.testing import do_bench - -DTYPE = torch.bfloat16 - -def bench_normal(fn, warmup=50, rep=200): - return do_bench(fn, warmup=warmup, rep=rep, return_mode="median") - -def bench_cudagraph(fn, warmup=50, rep=200): - for _ in range(warmup): - fn() - torch.cuda.synchronize() - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - fn() - torch.cuda.synchronize() - return do_bench(lambda: g.replay(), warmup=10, rep=rep, return_mode="median") - -mode = sys.argv[1] -shapes_json = sys.argv[2] -SHAPES = json.loads(shapes_json) - -if mode == "oink": - import kernelagent_oink - kernelagent_oink.register_all_kernels(force=True) - -# Warm up -for M, N in SHAPES: - x = torch.randn(M, N, dtype=DTYPE, device="cuda") - w = torch.randn(N, dtype=DTYPE, device="cuda") - torch.ops.aten._fused_rms_norm(x, [N], w, 1e-5) -torch.cuda.synchronize() - -results = {} -for M, N in SHAPES: - x = torch.randn(M, N, dtype=DTYPE, device="cuda", requires_grad=True) - w = torch.randn(N, dtype=DTYPE, device="cuda", requires_grad=True) - grad = torch.randn(M, N, dtype=DTYPE, device="cuda") - - # Forward (normal) - def fn_fwd(x=x, w=w, N=N): - return torch.ops.aten._fused_rms_norm(x, [N], w, 1e-5) - fwd_ms = bench_normal(fn_fwd) - - # Forward + Backward (normal) - x_ = x.detach().requires_grad_(True) - w_ = w.detach().requires_grad_(True) - def fn_fwdbwd(x_=x_, w_=w_, N=N, grad=grad): - y, _ = torch.ops.aten._fused_rms_norm(x_, [N], w_, 1e-5) - y.backward(grad) - fwdbwd_ms = bench_normal(fn_fwdbwd) - - # Forward (CUDA graph) - x_g = torch.randn(M, N, dtype=DTYPE, device="cuda") - w_g = torch.randn(N, dtype=DTYPE, device="cuda") - def fn_fwd_g(x=x_g, w=w_g, N=N): - return torch.ops.aten._fused_rms_norm(x, [N], w, 1e-5) - try: - fwd_graph_ms = bench_cudagraph(fn_fwd_g) - except Exception: - 
fwd_graph_ms = -1.0 - - # Forward + Backward (CUDA graph) - x_gb = torch.randn(M, N, dtype=DTYPE, device="cuda", requires_grad=True) - w_gb = torch.randn(N, dtype=DTYPE, device="cuda", requires_grad=True) - grad_gb = torch.randn(M, N, dtype=DTYPE, device="cuda") - def fn_fwdbwd_g(x=x_gb, w=w_gb, N=N, grad=grad_gb): - y, _ = torch.ops.aten._fused_rms_norm(x, [N], w, 1e-5) - y.backward(grad) - try: - fwdbwd_graph_ms = bench_cudagraph(fn_fwdbwd_g) - except Exception: - fwdbwd_graph_ms = -1.0 - - results[f"{M}x{N}"] = { - "fwd": fwd_ms, - "fwdbwd": fwdbwd_ms, - "fwd_graph": fwd_graph_ms, - "fwdbwd_graph": fwdbwd_graph_ms, - } - -print(json.dumps({"mode": mode, "results": results})) -""" - - -# --------------------------------------------------------------------------- -# Main: orchestrates subprocesses and prints tables. -# --------------------------------------------------------------------------- - -SHAPES = [ - [1, 4096], - [1, 8192], - [32, 4096], - [32, 8192], - [256, 4096], - [256, 8192], - [1024, 4096], - [1024, 8192], - [4096, 4096], - [4096, 8192], - [16384, 4096], - [16384, 8192], - [65536, 4096], - [65536, 8192], -] - -COL_W = { # column widths - "shape": 14, - "ms": 10, - "ratio": 8, -} - - -def find_norm_dir(): - import torch - from pathlib import Path - - d = Path(torch.__file__).parent / "_native" / "ops" / "norm" - return str(d) if d.is_dir() else None - - -def run_mode(mode, norm_dir, shapes): - init_file = os.path.join(norm_dir, "__init__.py") - - if mode in ("aten", "oink"): - with open(init_file, "w") as f: - f.write("") - elif mode == "quack": - with open(init_file, "w") as f: - f.write("from . 
import rmsnorm_impl # noqa: F401\n") - - with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as tmp: - tmp.write(WORKER_CODE) - tmp_path = tmp.name - - try: - result = subprocess.run( - [sys.executable, tmp_path, mode, json.dumps(shapes)], - capture_output=True, - text=True, - timeout=600, - ) - if result.returncode != 0: - print(f" [{mode}] FAILED: {result.stderr[-300:]}", file=sys.stderr) - return None - return json.loads(result.stdout.strip())["results"] - finally: - os.unlink(tmp_path) - - -def _fmt_ms(v): - return f"{v:>{COL_W['ms']}.4f}" if v > 0 else "FAIL".rjust(COL_W["ms"]) - - -def _fmt_ratio(n, d): - if d <= 0 or n <= 0: - return "N/A".rjust(COL_W["ratio"]) - return f"{f'{n / d:.2f}x':>{COL_W['ratio']}}" - - -def print_table(title, subtitle, aten, quack, oink, key): - sw, mw, rw = COL_W["shape"], COL_W["ms"], COL_W["ratio"] - w = [sw, mw, mw, mw, rw, rw, rw] - - def hr(left, mid, right): - return left + mid.join("─" * (c + 2) for c in w) + right - - hdr = ( - f"│ {'Shape (M,N)':^{sw}} " - f"│ {'Aten (ms)':^{mw}} " - f"│ {'Quack (ms)':^{mw}} " - f"│ {'Oink (ms)':^{mw}} " - f"│ {'Q/A':^{rw}} " - f"│ {'O/A':^{rw}} " - f"│ {'O/Q':^{rw}} │" - ) - - print() - print(f" {title}") - print(f" {subtitle}") - print(hr("┌", "┬", "┐")) - print(hdr) - print(hr("├", "┼", "┤")) - - for shape in aten: - M, N = shape.split("x") - a, q, o = aten[shape][key], quack[shape][key], oink[shape][key] - row = ( - f"│ {f'({M},{N})':>{sw}} " - f"│ {_fmt_ms(a)} " - f"│ {_fmt_ms(q)} " - f"│ {_fmt_ms(o)} " - f"│ {_fmt_ratio(a, q)} " - f"│ {_fmt_ratio(a, o)} " - f"│ {_fmt_ratio(q, o)} │" - ) - print(row) - - print(hr("└", "┴", "┘")) - - -def main(): - import torch - - print("=" * 72) - print(" RMSNorm Kernel Benchmark: Aten vs Quack vs Oink") - print("=" * 72) - print(f" Device : {torch.cuda.get_device_name(0)}") - print(f" Torch : {torch.__version__}") - print(" Dtype : bfloat16") - print(" Quack : registered via torch._native (quack PR)") - print(" Oink : registered 
via kernelagent_oink.register_all_kernels()") - print(" Bench : triton.testing.do_bench (median, 200 reps)") - - norm_dir = find_norm_dir() - if norm_dir is None: - print("ERROR: torch._native/ops/norm/ not found.", file=sys.stderr) - sys.exit(1) - - print() - print("Running aten...") - aten = run_mode("aten", norm_dir, SHAPES) - print("Running quack...") - quack = run_mode("quack", norm_dir, SHAPES) - print("Running oink...") - oink = run_mode("oink", norm_dir, SHAPES) - - # Restore - with open(os.path.join(norm_dir, "__init__.py"), "w") as f: - f.write("from . import rmsnorm_impl # noqa: F401\n") - - if not all([aten, quack, oink]): - print("ERROR: one or more modes failed.", file=sys.stderr) - sys.exit(1) - - print_table( - "Forward — Normal Dispatch", - "Standard Python dispatch through torch.ops.aten._fused_rms_norm.", - aten, - quack, - oink, - "fwd", - ) - print_table( - "Forward + Backward — Normal Dispatch", - "Fwd + autograd backward, standard Python dispatch.", - aten, - quack, - oink, - "fwdbwd", - ) - print_table( - "Forward — CUDA Graph (zero Python overhead)", - "Kernel captured once, replayed without re-entering Python.", - aten, - quack, - oink, - "fwd_graph", - ) - print_table( - "Forward + Backward — CUDA Graph (zero Python overhead)", - "Fwd + bwd captured once, replayed without re-entering Python.", - aten, - quack, - oink, - "fwdbwd_graph", - ) - - print() - print("Done.") - - -if __name__ == "__main__": - main() diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py index e137de7..1038597 100644 --- a/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py +++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py @@ -18,7 +18,7 @@ import csv import os from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import List, Tuple import torch from triton.testing import do_bench as triton_do_bench @@ -29,6 +29,8 @@ from bench_utils 
import ( # noqa: E402 ErrorStatsAccumulator, collect_device_meta, + detect_hbm_peak_gbps, + dsv4_norm_configs, ensure_blackwell_arch_env, ensure_oink_src_on_path, error_stats_to_row, @@ -55,17 +57,6 @@ } -def detect_hbm_peak_gbps(device: Optional[torch.device] = None) -> float: - """Approximate HBM peak bandwidth in GB/s for roofline fractions.""" - if device is None: - device = torch.device("cuda") - props = torch.cuda.get_device_properties(device) - sm = props.major * 10 + props.minor - if sm >= 100: - return 8000.0 - return 2000.0 - - @dataclass class Result: ms: float @@ -360,6 +351,11 @@ def main() -> None: action="store_true", help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}", ) + p.add_argument( + "--dsv4", + action="store_true", + help="Run DSv4 norm set: M in {4096,16384,65536}, N in {7168,1536,512}", + ) p.add_argument( "--skip-verify", action="store_true", @@ -378,6 +374,8 @@ def main() -> None: cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()] elif args.dsv3: cfgs = dsv3_configs() + elif args.dsv4: + cfgs = dsv4_norm_configs() else: cfgs = parse_configs(args.configs) diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py index 809a475..fccafed 100644 --- a/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py +++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py @@ -28,6 +28,7 @@ collect_device_meta, detect_hbm_peak_gbps, do_bench_triton, + dsv4_norm_configs, ensure_blackwell_arch_env, error_stats_to_row, ensure_oink_src_on_path, @@ -285,6 +286,11 @@ def main() -> None: action="store_true", help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}", ) + p.add_argument( + "--dsv4", + action="store_true", + help="Run DSv4 norm set: M in {4096,16384,65536}, N in {7168,1536,512}", + ) p.add_argument( "--skip-verify", action="store_true", @@ -303,6 +309,8 @@ def main() -> None: cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in 
quack_suite_configs()]
     elif args.dsv3:
         cfgs = dsv3_configs()
+    elif args.dsv4:
+        cfgs = dsv4_norm_configs()
     else:
         cfgs = parse_configs(args.configs)
diff --git a/oink/benchmarks/media/sm103_bf16_oink_vs_quack_with_layernorm.svg b/oink/benchmarks/media/sm103_bf16_oink_vs_quack_with_layernorm.svg
new file mode 100644
index 0000000..0e16103
--- /dev/null
+++ b/oink/benchmarks/media/sm103_bf16_oink_vs_quack_with_layernorm.svg
@@ -0,0 +1,2627 @@
+[2,627 lines of SVG markup elided: Matplotlib v3.10.9 benchmark plot, generated 2026-04-29T13:05:02]
diff --git a/oink/benchmarks/readme/plot_quack_style_svg.py b/oink/benchmarks/readme/plot_quack_style_svg.py
index 88eebdf..5e0224a 100644
--- a/oink/benchmarks/readme/plot_quack_style_svg.py
+++ b/oink/benchmarks/readme/plot_quack_style_svg.py
@@ -13,8 +13,8 @@
 """
-Generate Quack-style SVG performance plots (Oink vs Quack) from the SM100 suite
-JSON artifacts under `/tmp/kernelagent_oink_sm100_suite_{bf16,fp16}`.
+Generate Quack-style SVG performance plots (Oink vs Quack) from SM10x suite
+JSON artifacts under a suite output directory.
The intent is to match Quack's README visual style: - 3 horizontal panels (suite-dependent): @@ -240,12 +240,13 @@ def _plot( label="Quack", ) if roofline_gbps is not None: + roof_label = f"HBM peak (measured {roofline_gbps / 1000.0:.3f} TB/s)" ax.axhline( roofline_gbps, color=COLOR_ROOF, linewidth=3, linestyle=(0, (4, 6)), - label="HBM peak (measured)" if ax is axes[0] else None, + label=roof_label if ax is axes[0] else None, ) max_y = max(max_y, float(roofline_gbps)) @@ -389,7 +390,19 @@ def main() -> None: "--roofline-json", type=str, default=None, - help="Optional /tmp/hbm_roofline_sm100_*.json path", + help="Optional measured roofline JSON path from benchmark_hbm_roofline_sm100.py", + ) + p.add_argument( + "--roofline-gbps", + type=float, + default=None, + help="Optional measured roofline in GB/s (mutually exclusive with --roofline-json).", + ) + p.add_argument( + "--arch-label", + type=str, + default="SM100", + help="Architecture label used in auto-generated titles, e.g. 'SM103 / GB300'.", ) p.add_argument("--out", type=str, required=True, help="Output SVG path") p.add_argument( @@ -401,8 +414,12 @@ def main() -> None: if not os.path.isdir(in_dir): raise SystemExit(f"--in-dir is not a directory: {in_dir}") + if args.roofline_json is not None and args.roofline_gbps is not None: + raise SystemExit("Use only one of --roofline-json or --roofline-gbps.") roofline_gbps = ( - _read_roofline_gbps(args.roofline_json) if args.roofline_json else None + float(args.roofline_gbps) + if args.roofline_gbps is not None + else (_read_roofline_gbps(args.roofline_json) if args.roofline_json else None) ) panel_files = list(_panel_files_for_suite(str(args.suite))) @@ -451,10 +468,11 @@ def main() -> None: if (args.suite == "quack_suite" and args.include_layernorm) else "" ) + arch_label = str(args.arch_label) if args.suite == "dsv3_cross_entropy": - title = f"SM100 {dtype.upper()} — {suite_name}{suffix}" + title = f"{arch_label} {dtype.upper()} — {suite_name}{suffix}" else: - 
title = f"SM100 {dtype.upper()} Kernel Benchmarks (Oink vs Quack) — {suite_name}{suffix}" + title = f"{arch_label} {dtype.upper()} Kernel Benchmarks (Oink vs Quack) — {suite_name}{suffix}" _plot( panels=panels, diff --git a/oink/benchmarks/readme/run_sm100_suite.py b/oink/benchmarks/readme/run_sm100_suite.py index 05bd211..4ef1bf4 100644 --- a/oink/benchmarks/readme/run_sm100_suite.py +++ b/oink/benchmarks/readme/run_sm100_suite.py @@ -56,6 +56,11 @@ def main() -> None: action="store_true", help="Skip correctness checks (Oink/Quack vs PyTorch / pure-PyTorch references)", ) + p.add_argument( + "--include-dsv4", + action="store_true", + help="Also run DeepSeek-V4-Flash norm workloads (RMSNorm N={7168,1536,512}; LayerNorm/fused-add N=7168).", + ) p.add_argument( "--dry-run", action="store_true", help="Print commands without executing them" ) @@ -83,260 +88,370 @@ def script(name: str) -> str: if args.skip_verify: common = [*common, "--skip-verify"] - runs: List[Tuple[str, List[str]]] = [ - ( - "rmsnorm_fwd_quack_suite_wfp32", - [ - py, - script("benchmark_rmsnorm_sm100.py"), - *common, - "--weight-dtype", - "fp32", - "--quack-suite", - "--iters", - "200", - "--warmup-ms", - "25", - "--json", - os.path.join(out_dir, "rmsnorm_fwd_quack_suite_wfp32.json"), - ], - ), - ( - "rmsnorm_fwd_dsv3_wfp32", - [ - py, - script("benchmark_rmsnorm_sm100.py"), - *common, - "--weight-dtype", - "fp32", - "--dsv3", - "--iters", - "200", - "--warmup-ms", - "25", - "--json", - os.path.join(out_dir, "rmsnorm_fwd_dsv3_wfp32.json"), - ], - ), - ( - "rmsnorm_bwd_quack_suite_wfp32", - [ - py, - script("benchmark_rmsnorm_bwd_sm100.py"), - *common, - "--weight-dtype", - "fp32", - "--quack-suite", - "--iters", - "100", - "--warmup-ms", - "25", - "--json", - os.path.join(out_dir, "rmsnorm_bwd_quack_suite_wfp32.json"), - ], - ), - ( - "rmsnorm_bwd_dsv3_wfp32", - [ - py, - script("benchmark_rmsnorm_bwd_sm100.py"), - *common, - "--weight-dtype", - "fp32", - "--dsv3", - "--iters", - "100", - 
"--warmup-ms", - "25", - "--json", - os.path.join(out_dir, "rmsnorm_bwd_dsv3_wfp32.json"), - ], - ), - # vLLM inference-style RMSNorm (weight dtype == activation dtype). - ( - "rmsnorm_fwd_quack_suite_wsame", - [ - py, - script("benchmark_rmsnorm_sm100.py"), - *common, - "--weight-dtype", - "same", - "--quack-suite", - "--iters", - "200", - "--warmup-ms", - "25", - "--json", - os.path.join(out_dir, "rmsnorm_fwd_quack_suite_wsame.json"), - ], - ), - ( - "rmsnorm_fwd_dsv3_wsame", - [ - py, - script("benchmark_rmsnorm_sm100.py"), - *common, - "--weight-dtype", - "same", - "--dsv3", - "--iters", - "200", - "--warmup-ms", - "25", - "--json", - os.path.join(out_dir, "rmsnorm_fwd_dsv3_wsame.json"), - ], - ), - ( - "rmsnorm_bwd_quack_suite_wsame", - [ - py, - script("benchmark_rmsnorm_bwd_sm100.py"), - *common, - "--weight-dtype", - "same", - "--quack-suite", - "--iters", - "100", - "--warmup-ms", - "25", - "--json", - os.path.join(out_dir, "rmsnorm_bwd_quack_suite_wsame.json"), - ], - ), - ( - "rmsnorm_bwd_dsv3_wsame", - [ - py, - script("benchmark_rmsnorm_bwd_sm100.py"), - *common, - "--weight-dtype", - "same", - "--dsv3", - "--iters", - "100", - "--warmup-ms", - "25", - "--json", - os.path.join(out_dir, "rmsnorm_bwd_dsv3_wsame.json"), - ], - ), - ( - "fused_add_rmsnorm_dsv3", - [ - py, - script("benchmark_fused_add_rmsnorm_sm100.py"), - *common, - "--dsv3", - "--quack-baseline", - "kernel_inplace", - "--iters", - "200", - "--warmup-ms", - "25", - "--json", - os.path.join(out_dir, "fused_add_rmsnorm_dsv3.json"), - ], - ), - ( - "softmax_fwd_bwd_quack_suite", - [ - py, - script("benchmark_softmax_sm100.py"), - *common, - "--mode", - "fwd_bwd", - "--quack-suite", - "--iters", - "50", - "--warmup-ms", - "25", - "--json", - os.path.join(out_dir, "softmax_fwd_bwd_quack_suite.json"), - ], - ), - ( - "softmax_fwd_bwd_dsv3", - [ - py, - script("benchmark_softmax_sm100.py"), - *common, - "--mode", - "fwd_bwd", - "--dsv3", - "--iters", - "50", - "--warmup-ms", - "25", - "--json", 
- os.path.join(out_dir, "softmax_fwd_bwd_dsv3.json"), - ], - ), - ( - "cross_entropy_fwd_bwd_quack_suite", - [ - py, - script("benchmark_cross_entropy_sm100.py"), - *common, - "--mode", - "fwd_bwd", - "--quack-suite", - "--iters", - "50", - "--warmup-ms", - "25", - "--json", - os.path.join(out_dir, "cross_entropy_fwd_bwd_quack_suite.json"), - ], - ), - ( - "cross_entropy_fwd_bwd_dsv3", - [ - py, - script("benchmark_cross_entropy_sm100.py"), - *common, - "--mode", - "fwd_bwd", - "--dsv3", - "--iters", - "50", - "--warmup-ms", - "25", - "--json", - os.path.join(out_dir, "cross_entropy_fwd_bwd_dsv3.json"), - ], - ), - ( - "layernorm_fwd_quack_suite", - [ - py, - script("benchmark_layernorm_sm100.py"), - *common, - "--quack-suite", - "--iters", - "200", - "--warmup-ms", - "25", - "--json", - os.path.join(out_dir, "layernorm_fwd_quack_suite.json"), - ], - ), - ( - "layernorm_fwd_dsv3", + runs: List[Tuple[str, List[str]]] = [] + + if args.include_dsv4: + runs.extend( [ - py, - script("benchmark_layernorm_sm100.py"), - *common, - "--dsv3", - "--iters", - "200", - "--warmup-ms", - "25", - "--json", - os.path.join(out_dir, "layernorm_fwd_dsv3.json"), - ], - ), - ] + ( + "rmsnorm_fwd_dsv4_wfp32", + [ + py, + script("benchmark_rmsnorm_sm100.py"), + *common, + "--weight-dtype", + "fp32", + "--dsv4", + "--iters", + "200", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_fwd_dsv4_wfp32.json"), + ], + ), + ( + "rmsnorm_fwd_dsv4_wsame", + [ + py, + script("benchmark_rmsnorm_sm100.py"), + *common, + "--weight-dtype", + "same", + "--dsv4", + "--iters", + "200", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_fwd_dsv4_wsame.json"), + ], + ), + ( + "rmsnorm_bwd_dsv4_wfp32", + [ + py, + script("benchmark_rmsnorm_bwd_sm100.py"), + *common, + "--weight-dtype", + "fp32", + "--dsv4", + "--iters", + "100", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_bwd_dsv4_wfp32.json"), + ], + ), + ( + "rmsnorm_bwd_dsv4_wsame", + [ + py, 
+ script("benchmark_rmsnorm_bwd_sm100.py"), + *common, + "--weight-dtype", + "same", + "--dsv4", + "--iters", + "100", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_bwd_dsv4_wsame.json"), + ], + ), + ( + "fused_add_rmsnorm_dsv4", + [ + py, + script("benchmark_fused_add_rmsnorm_sm100.py"), + *common, + "--dsv4", + "--quack-baseline", + "kernel_inplace", + "--iters", + "200", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "fused_add_rmsnorm_dsv4.json"), + ], + ), + ( + "layernorm_fwd_dsv4", + [ + py, + script("benchmark_layernorm_sm100.py"), + *common, + "--dsv4", + "--iters", + "200", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "layernorm_fwd_dsv4.json"), + ], + ), + ] + ) + + runs.extend( + [ + ( + "rmsnorm_fwd_quack_suite_wfp32", + [ + py, + script("benchmark_rmsnorm_sm100.py"), + *common, + "--weight-dtype", + "fp32", + "--quack-suite", + "--iters", + "200", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_fwd_quack_suite_wfp32.json"), + ], + ), + ( + "rmsnorm_fwd_dsv3_wfp32", + [ + py, + script("benchmark_rmsnorm_sm100.py"), + *common, + "--weight-dtype", + "fp32", + "--dsv3", + "--iters", + "200", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_fwd_dsv3_wfp32.json"), + ], + ), + ( + "rmsnorm_bwd_quack_suite_wfp32", + [ + py, + script("benchmark_rmsnorm_bwd_sm100.py"), + *common, + "--weight-dtype", + "fp32", + "--quack-suite", + "--iters", + "100", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_bwd_quack_suite_wfp32.json"), + ], + ), + ( + "rmsnorm_bwd_dsv3_wfp32", + [ + py, + script("benchmark_rmsnorm_bwd_sm100.py"), + *common, + "--weight-dtype", + "fp32", + "--dsv3", + "--iters", + "100", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_bwd_dsv3_wfp32.json"), + ], + ), + # vLLM inference-style RMSNorm (weight dtype == activation dtype). 
+ ( + "rmsnorm_fwd_quack_suite_wsame", + [ + py, + script("benchmark_rmsnorm_sm100.py"), + *common, + "--weight-dtype", + "same", + "--quack-suite", + "--iters", + "200", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_fwd_quack_suite_wsame.json"), + ], + ), + ( + "rmsnorm_fwd_dsv3_wsame", + [ + py, + script("benchmark_rmsnorm_sm100.py"), + *common, + "--weight-dtype", + "same", + "--dsv3", + "--iters", + "200", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_fwd_dsv3_wsame.json"), + ], + ), + ( + "rmsnorm_bwd_quack_suite_wsame", + [ + py, + script("benchmark_rmsnorm_bwd_sm100.py"), + *common, + "--weight-dtype", + "same", + "--quack-suite", + "--iters", + "100", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_bwd_quack_suite_wsame.json"), + ], + ), + ( + "rmsnorm_bwd_dsv3_wsame", + [ + py, + script("benchmark_rmsnorm_bwd_sm100.py"), + *common, + "--weight-dtype", + "same", + "--dsv3", + "--iters", + "100", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_bwd_dsv3_wsame.json"), + ], + ), + ( + "fused_add_rmsnorm_dsv3", + [ + py, + script("benchmark_fused_add_rmsnorm_sm100.py"), + *common, + "--dsv3", + "--quack-baseline", + "kernel_inplace", + "--iters", + "200", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "fused_add_rmsnorm_dsv3.json"), + ], + ), + ( + "softmax_fwd_bwd_quack_suite", + [ + py, + script("benchmark_softmax_sm100.py"), + *common, + "--mode", + "fwd_bwd", + "--quack-suite", + "--iters", + "50", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "softmax_fwd_bwd_quack_suite.json"), + ], + ), + ( + "softmax_fwd_bwd_dsv3", + [ + py, + script("benchmark_softmax_sm100.py"), + *common, + "--mode", + "fwd_bwd", + "--dsv3", + "--iters", + "50", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "softmax_fwd_bwd_dsv3.json"), + ], + ), + ( + "cross_entropy_fwd_bwd_quack_suite", + [ + py, + script("benchmark_cross_entropy_sm100.py"), + 
*common, + "--mode", + "fwd_bwd", + "--quack-suite", + "--iters", + "50", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "cross_entropy_fwd_bwd_quack_suite.json"), + ], + ), + ( + "cross_entropy_fwd_bwd_dsv3", + [ + py, + script("benchmark_cross_entropy_sm100.py"), + *common, + "--mode", + "fwd_bwd", + "--dsv3", + "--iters", + "50", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "cross_entropy_fwd_bwd_dsv3.json"), + ], + ), + ( + "layernorm_fwd_quack_suite", + [ + py, + script("benchmark_layernorm_sm100.py"), + *common, + "--quack-suite", + "--iters", + "200", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "layernorm_fwd_quack_suite.json"), + ], + ), + ( + "layernorm_fwd_dsv3", + [ + py, + script("benchmark_layernorm_sm100.py"), + *common, + "--dsv3", + "--iters", + "200", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "layernorm_fwd_dsv3.json"), + ], + ), + ] + ) print(f"Writing results to: {out_dir}", flush=True) for name, cmd in runs: diff --git a/oink/benchmarks/readme/summarize_results.py b/oink/benchmarks/readme/summarize_results.py index 684694d..efc92e4 100644 --- a/oink/benchmarks/readme/summarize_results.py +++ b/oink/benchmarks/readme/summarize_results.py @@ -232,7 +232,7 @@ def main() -> None: raise SystemExit(f"No .json files found under: {in_dir}") out_parts: List[str] = [] - out_parts.append("# KernelAgent-Oink SM100 Benchmark Summary") + out_parts.append("# KernelAgent-Oink SM10x Benchmark Summary") out_parts.append("") out_parts.append(f"Input directory: `{in_dir}`") out_parts.append("") diff --git a/oink/pyproject.toml b/oink/pyproject.toml index e0f1927..9dedd92 100644 --- a/oink/pyproject.toml +++ b/oink/pyproject.toml @@ -5,12 +5,12 @@ build-backend = "setuptools.build_meta" [project] name = "kernelagent-oink" version = "0.1.0" -description = "CuTeDSL kernels for Blackwell (SM100), shipped as a vLLM plugin" +description = "CuTeDSL kernels for Blackwell SM10x (SM100-SM103), shipped as a 
vLLM plugin" readme = "README.md" requires-python = ">=3.10" license = {text = "Apache-2.0"} authors = [{name = "PyTorch Labs"}] -keywords = ["cuda", "cutlass", "cute", "cutedsl", "blackwell", "sm100", "vllm"] +keywords = ["cuda", "cutlass", "cute", "cutedsl", "blackwell", "sm100", "sm103", "gb300", "vllm"] classifiers = [ "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3", @@ -27,7 +27,7 @@ classifiers = [ # We intentionally do NOT depend on `torch` here because vLLM already pins and # provides a compatible PyTorch build. dependencies = [ - "nvidia-cutlass-dsl>=4.2.1", + "nvidia-cutlass-dsl>=4.4.2", "cuda-python", ] diff --git a/oink/src/kernelagent_oink/__init__.py b/oink/src/kernelagent_oink/__init__.py index d61b60e..4fe33b1 100644 --- a/oink/src/kernelagent_oink/__init__.py +++ b/oink/src/kernelagent_oink/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. """ -KernelAgent-Oink: SM100 CuTeDSL kernels + optional vLLM plugin. +KernelAgent-Oink: Blackwell SM10x CuTeDSL kernels + optional vLLM plugin. This package can be loaded as a vLLM "general plugin" (entrypoint group `vllm.general_plugins`). In that mode it registers Oink custom ops only when @@ -135,7 +135,7 @@ def register(*, force: bool = False) -> None: def register_all_kernels(*, force: bool = False) -> None: """Override aten ops with Oink's kernels. - Checks CUDA/SM100/deps, sets up the CuTeDSL environment, then overrides + Checks CUDA/Blackwell SM10x/deps, sets up the CuTeDSL environment, then overrides ``aten::_fused_rms_norm`` and ``aten::_fused_rms_norm_backward`` on CUDA. Does NOT register ``torch.ops.oink.*`` custom ops — use :func:`register` diff --git a/oink/src/kernelagent_oink/blackwell/_cutedsl_cache.py b/oink/src/kernelagent_oink/blackwell/_cutedsl_cache.py new file mode 100644 index 0000000..2e2536d --- /dev/null +++ b/oink/src/kernelagent_oink/blackwell/_cutedsl_cache.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CuTeDSL cache setup shared by Blackwell kernel modules. + +CuTeDSL cache bytecode is version-sensitive. The default global cache +(``/tmp/$USER/cutlass_python_cache``) can be shared across environments with +incompatible ``nvidia-cutlass-dsl`` versions, producing noisy warnings and +lost cache reuse. Call this helper before importing ``cutlass`` in modules +that compile CuTeDSL kernels. +""" + +from __future__ import annotations + +import importlib.metadata +import os +import re + + +def ensure_versioned_cutedsl_cache_dir() -> None: + """Set a version-scoped CuTeDSL cache directory when the user did not. + + The path format intentionally matches the historical per-module logic: + ``$TMPDIR/$USER/cutlass_python_cache_``. + If ``CUTE_DSL_CACHE_DIR`` is already set, leave it untouched. 
+    """
+    if "CUTE_DSL_CACHE_DIR" in os.environ:
+        return
+    try:
+        dsl_ver = importlib.metadata.version("nvidia-cutlass-dsl")
+    except Exception:
+        dsl_ver = "unknown"
+    dsl_ver = re.sub(r"[^0-9A-Za-z]+", "_", dsl_ver)
+    user = os.environ.get("USER") or os.environ.get("USERNAME") or "user"
+    tmp = os.environ.get("TMPDIR") or "/tmp"
+    os.environ["CUTE_DSL_CACHE_DIR"] = os.path.join(
+        tmp, user, f"cutlass_python_cache_{dsl_ver}"
+    )
diff --git a/oink/src/kernelagent_oink/blackwell/_rmsnorm_impl.py b/oink/src/kernelagent_oink/blackwell/_rmsnorm_impl.py
index cbea22d..3af5ea3 100644
--- a/oink/src/kernelagent_oink/blackwell/_rmsnorm_impl.py
+++ b/oink/src/kernelagent_oink/blackwell/_rmsnorm_impl.py
@@ -16,25 +16,14 @@
 
 from __future__ import annotations
 
-import importlib.metadata
 import os
-import re
 from dataclasses import dataclass, replace
 
 # Vendored/adapted from Quack's SM100 RMSNorm with Oink-specific B200 tuning.
-# CuTeDSL cache bytecode is version-sensitive, so isolate the default cache.
-if "CUTE_DSL_CACHE_DIR" not in os.environ:
-    try:
-        _dsl_ver = importlib.metadata.version("nvidia-cutlass-dsl")
-    except Exception:
-        _dsl_ver = "unknown"
-    _dsl_ver = re.sub(r"[^0-9A-Za-z]+", "_", _dsl_ver)
-    _user = os.environ.get("USER") or os.environ.get("USERNAME") or "user"
-    _tmp = os.environ.get("TMPDIR") or "/tmp"
-    os.environ["CUTE_DSL_CACHE_DIR"] = os.path.join(
-        _tmp, _user, f"cutlass_python_cache_{_dsl_ver}"
-    )
+from kernelagent_oink.blackwell._cutedsl_cache import ensure_versioned_cutedsl_cache_dir
+
+ensure_versioned_cutedsl_cache_dir()
 
 try:
     import cutlass  # type: ignore  # noqa: F401
@@ -182,7 +171,16 @@ def _resolve_forward_launch_config(
     weight_dtype: type[cutlass.Numeric] | None,
     aligned_tensors: tuple[Tensor | None, ...],
 ) -> _ForwardLaunchConfig:
-    direct_gmem_default = bool(dtype.width == 16 and N in {128, 4096, 6144, 7168, 8192})
+    direct_gmem_default = bool(
+        dtype.width == 16 and N in {128, 512, 4096, 6144, 7168, 8192}
+    )
+    if (
+        dtype.width == 16
+        and N == 1536
+        and weight_dtype is not None
+        and weight_dtype.width == 16
+    ):
+        direct_gmem_default = True
     if weight_dtype is not None and weight_dtype.width == 32 and N == 7168:
         direct_gmem_default = False
     direct_gmem = _direct_gmem_from_policy(default=direct_gmem_default)
@@ -197,7 +195,7 @@
     default_copy_bits = 256 if can_use_256 else 128
     if dtype.width == 16 and N == 128:
         default_copy_bits = 128
-    if dtype.width == 16 and N == 4096:
+    if dtype.width == 16 and N in {512, 1536, 4096}:
         default_copy_bits = 128
     if dtype.width == 16 and weight_dtype is not None and weight_dtype.width == 32:
         default_copy_bits = 128 if N == 4096 else 64
@@ -205,6 +203,12 @@
     copy_bits = _copy_bits_from_policy(
         default=default_copy_bits, can_use_256=can_use_256
    )
+    # cp.async supports at most 128 bits per instruction. The copy atom clamps
+    # async copies to 128b, so keep the TV layout's vector width in sync with the
+    # emitted copy width; otherwise shapes such as DSv4 N=1536 can leave half of
+    # each logical vector tile uninitialized.
+    if use_async and copy_bits > 128:
+        copy_bits = 128
     if use_async and copy_bits < 128:
         use_async = False
@@ -284,6 +288,16 @@
     nt_default: int | None = None
     cluster_n_default: int | None = None
 
+    if (
+        dtype.width == 16
+        and weight_dtype is not None
+        and weight_dtype.width == 16
+        and N == 1536
+        and direct_gmem
+        and M >= 4096
+    ):
+        tpr_default = 32
+        nt_default = 32
     if (
         dtype.width == 16
         and weight_dtype is not None
diff --git a/oink/src/kernelagent_oink/blackwell/_rmsnorm_simple_weightonly.py b/oink/src/kernelagent_oink/blackwell/_rmsnorm_simple_weightonly.py
index 2a790f6..63e5083 100644
--- a/oink/src/kernelagent_oink/blackwell/_rmsnorm_simple_weightonly.py
+++ b/oink/src/kernelagent_oink/blackwell/_rmsnorm_simple_weightonly.py
@@ -1,8 +1,23 @@
+"""Measured same-dtype bf16 RMSNorm forward specializations.
+
+This module implements a narrow fast path for row-major bf16 tensors with a
+bf16 1D weight and no residual/bias/rstd outputs. The math is
+``y = x * rsqrt(mean(x * x) + eps) * weight`` with fp32 reduction/multiply and
+bf16 output. The shape table below is deliberately narrow, covering only
+measured shapes; anything else falls back to the generic RMSNorm pointer path.
+"""
+# ruff: noqa: E402  # CuTeDSL cache setup must run before importing cutlass.
+
+from dataclasses import dataclass
+
 import cuda.bindings.driver as cuda
 import torch
-from dataclasses import dataclass
 from torch import Tensor
+
+from kernelagent_oink.blackwell._cutedsl_cache import ensure_versioned_cutedsl_cache_dir
+
+ensure_versioned_cutedsl_cache_dir()
+
 import cutlass
 import cutlass.cute as cute
 import cutlass.utils as utils
@@ -19,6 +34,10 @@
 _COMPILED_CACHE: dict[tuple[object, int, int, int], object] = {}
 
 _SIMPLE_WEIGHTONLY_SHAPES: dict[tuple[int, int], tuple[int, int]] = {
+    # DeepSeek-V4-Flash q_lora same-dtype RMSNorm shape. Larger M and the
+    # kv/per-head N=512 cases are faster through the generic pointer path on SM103.
+    (4096, 1536): (96, 96),
+    # DeepSeek-V3 hidden-state same-dtype RMSNorm shapes.
     (4096, 6144): (192, 192),
     (4096, 7168): (224, 224),
     (4096, 8192): (256, 256),
@@ -66,7 +85,12 @@ def cache_key(self) -> tuple[object, int, int, int]:
         )
 
     @staticmethod
-    def make_tv_layout(threads_per_row, rows_per_block, vec_size, num_vec_blocks):
+    def make_tv_layout(
+        threads_per_row: int,
+        rows_per_block: int,
+        vec_size: int,
+        num_vec_blocks: int,
+    ):
         shape = ((threads_per_row, rows_per_block), (vec_size, num_vec_blocks))
         stride = (
             (vec_size * rows_per_block, 1),
@@ -74,7 +98,7 @@
         )
         return shape, stride
 
-    def smem_bytes(self):
+    def smem_bytes(self) -> int:
         return (
             self.rows_per_block * self.cols_per_tile * (self.dtype.width // 8)
             + self.rows_per_block * self.warps_per_row * 4
@@ -210,10 +234,6 @@ def kernel(
         cute.copy(copy_atom_store, tXrO, tXgO)
 
-
-def _can_use_simple_weightonly(x: Tensor, weight: Tensor, out: Tensor) -> bool:
-    return _get_simple_weightonly_config(x, weight, out) is not None
-
 
 def _get_simple_weightonly_config(
     x: Tensor,
     weight: Tensor,
diff --git a/oink/src/kernelagent_oink/blackwell/cross_entropy.py b/oink/src/kernelagent_oink/blackwell/cross_entropy.py
index d8b37ea..5bb1568 100644
--- a/oink/src/kernelagent_oink/blackwell/cross_entropy.py
+++ b/oink/src/kernelagent_oink/blackwell/cross_entropy.py
@@ -1,3 +1,4 @@
+# ruff: noqa: E402  # CuTeDSL cache setup must run before importing cutlass.
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -39,10 +40,7 @@
 
 from __future__ import annotations
 
-import importlib.metadata
 import math
-import os
-import re
 from typing import Literal, Optional, Type
 
 import torch
@@ -50,21 +48,9 @@
 
 import cuda.bindings.driver as cuda  # provided by NVIDIA cuda-python
 
-# CuTeDSL caches generated MLIR into a tempdir under a global default
-# (`/tmp/$USER/cutlass_python_cache`). The cache bytecode format can differ across
-# `nvidia-cutlass-dsl` versions, and cross-version cache sharing causes noisy
-# warnings (and disables cache reuse).
-if "CUTE_DSL_CACHE_DIR" not in os.environ:
-    try:
-        _dsl_ver = importlib.metadata.version("nvidia-cutlass-dsl")
-    except Exception:
-        _dsl_ver = "unknown"
-    _dsl_ver = re.sub(r"[^0-9A-Za-z]+", "_", _dsl_ver)
-    _user = os.environ.get("USER") or os.environ.get("USERNAME") or "user"
-    _tmp = os.environ.get("TMPDIR") or "/tmp"
-    os.environ["CUTE_DSL_CACHE_DIR"] = os.path.join(
-        _tmp, _user, f"cutlass_python_cache_{_dsl_ver}"
-    )
+from kernelagent_oink.blackwell._cutedsl_cache import ensure_versioned_cutedsl_cache_dir
+
+ensure_versioned_cutedsl_cache_dir()
 
 try:
     import cutlass  # type: ignore  # noqa: F401
diff --git a/oink/src/kernelagent_oink/blackwell/layernorm.py b/oink/src/kernelagent_oink/blackwell/layernorm.py
index ada51ec..6b4b9c7 100644
--- a/oink/src/kernelagent_oink/blackwell/layernorm.py
+++ b/oink/src/kernelagent_oink/blackwell/layernorm.py
@@ -1,3 +1,4 @@
+# ruff: noqa: E402  # CuTeDSL cache setup must run before importing cutlass.
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -30,10 +31,7 @@
 
 from __future__ import annotations
 
-import importlib.metadata
 import math
-import os
-import re
 import operator
 from typing import Optional, Tuple, Type
 
@@ -42,21 +40,9 @@
 
 import cuda.bindings.driver as cuda  # provided by NVIDIA cuda-python
 
-# CuTeDSL caches generated MLIR into a tempdir under a global default
-# (`/tmp/$USER/cutlass_python_cache`). The cache bytecode format can differ across
-# `nvidia-cutlass-dsl` versions, and cross-version cache sharing causes noisy
-# warnings (and disables cache reuse).
-if "CUTE_DSL_CACHE_DIR" not in os.environ:
-    try:
-        _dsl_ver = importlib.metadata.version("nvidia-cutlass-dsl")
-    except Exception:
-        _dsl_ver = "unknown"
-    _dsl_ver = re.sub(r"[^0-9A-Za-z]+", "_", _dsl_ver)
-    _user = os.environ.get("USER") or os.environ.get("USERNAME") or "user"
-    _tmp = os.environ.get("TMPDIR") or "/tmp"
-    os.environ["CUTE_DSL_CACHE_DIR"] = os.path.join(
-        _tmp, _user, f"cutlass_python_cache_{_dsl_ver}"
-    )
+from kernelagent_oink.blackwell._cutedsl_cache import ensure_versioned_cutedsl_cache_dir
+
+ensure_versioned_cutedsl_cache_dir()
 
 try:
     import cutlass  # type: ignore  # noqa: F401
diff --git a/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py b/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py
index 282514d..348b4f5 100644
--- a/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py
+++ b/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py
@@ -16,7 +16,7 @@
 Torch custom ops wrapping Oink's Blackwell RMSNorm kernels.
 
 These ops are designed to be:
-- Architecture-aware (use CuTeDSL SM100 kernels when available, fall back
+- Architecture-aware (use CuTeDSL Blackwell SM10x kernels when available, fall back
   to a safe reference elsewhere).
 - Layout-preserving for 2D row-major inputs, including padded MLA-style
   layouts where stride(0) > N and stride(1) == 1.
@@ -69,7 +69,7 @@ def _get_rmsnorm_mod():
 
 
 def _get_sm(device: torch.device | None = None) -> int:
-    """Return SM version as an int (e.g., 100 for SM100 / Blackwell)."""
+    """Return SM version as an int (e.g., 103 for SM103 / Blackwell)."""
     if device is None:
         device = torch.device("cuda")
     major, minor = torch.cuda.get_device_capability(device)
@@ -95,7 +95,7 @@
     dimension stride(0) may be larger than N (padded-row layouts), and
     will be preserved on the fast CuTeDSL path.
 
-    On SM100 (and newer), this dispatches to the tuned CuTeDSL Blackwell
+    On Blackwell SM10x (SM100 and newer), this dispatches to the tuned CuTeDSL
     RMSNorm kernel in rmsnorm.rmsnorm_forward, which in turn selects the
     best internal schedule (including DSv3-specific stage-2 kernels where
     applicable) and preserves the input's 2D stride when using the
@@ -111,7 +111,7 @@
     sm = _get_sm(x.device)
     _rms = _get_rmsnorm_mod()
     if sm >= 100:
-        # Use the tuned CuTeDSL SM100 kernel. The public API already
+        # Use the tuned CuTeDSL Blackwell kernel. The public API already
         # contains all necessary gating and layout checks internally.
         y, _rstd, _res = _rms.rmsnorm_forward(
             x,
@@ -186,13 +186,13 @@
     _rms = _get_rmsnorm_mod()
 
     if sm < 100:
-        # Non-SM100 fallback: keep semantics in-place (correctness-first).
+        # Pre-SM100 fallback: keep semantics in-place (correctness-first).
         residual.add_(x)
         y = _rms.rmsnorm_ref(residual, w=weight, b=None, residual=None, eps=eps)
         x.copy_(y)
         return None
 
-    # SM100+: prefer the lowest-overhead in-place entrypoint (returns None).
+    # SM100 and newer: prefer the lowest-overhead in-place entrypoint (returns None).
     if hasattr(_rms, "fused_add_rmsnorm_inplace_"):
         _rms.fused_add_rmsnorm_inplace_(  # type: ignore[misc]
             x,
diff --git a/oink/src/kernelagent_oink/blackwell/softmax.py b/oink/src/kernelagent_oink/blackwell/softmax.py
index 394ab48..d364c5c 100644
--- a/oink/src/kernelagent_oink/blackwell/softmax.py
+++ b/oink/src/kernelagent_oink/blackwell/softmax.py
@@ -1,3 +1,4 @@
+# ruff: noqa: E402  # CuTeDSL cache setup must run before importing cutlass.
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -26,9 +27,6 @@
 
 from __future__ import annotations
 
-import importlib.metadata
-import os
-import re
 from typing import Type
 
 import torch
@@ -36,21 +34,9 @@
 
 import cuda.bindings.driver as cuda  # provided by NVIDIA cuda-python
 
-# CuTeDSL caches generated MLIR into a tempdir under a global default
-# (`/tmp/$USER/cutlass_python_cache`). The cache bytecode format can differ across
-# `nvidia-cutlass-dsl` versions, and cross-version cache sharing causes noisy
-# warnings (and disables cache reuse).
-if "CUTE_DSL_CACHE_DIR" not in os.environ:
-    try:
-        _dsl_ver = importlib.metadata.version("nvidia-cutlass-dsl")
-    except Exception:
-        _dsl_ver = "unknown"
-    _dsl_ver = re.sub(r"[^0-9A-Za-z]+", "_", _dsl_ver)
-    _user = os.environ.get("USER") or os.environ.get("USERNAME") or "user"
-    _tmp = os.environ.get("TMPDIR") or "/tmp"
-    os.environ["CUTE_DSL_CACHE_DIR"] = os.path.join(
-        _tmp, _user, f"cutlass_python_cache_{_dsl_ver}"
-    )
+from kernelagent_oink.blackwell._cutedsl_cache import ensure_versioned_cutedsl_cache_dir
+
+ensure_versioned_cutedsl_cache_dir()
 
 try:
     import cutlass  # type: ignore  # noqa: F401