From e1d25766281a22abd6485633aba1289e0f98f288 Mon Sep 17 00:00:00 2001
From: Laura Wang <3700467+Laurawly@users.noreply.github.com>
Date: Wed, 29 Apr 2026 14:27:44 -0700
Subject: [PATCH 1/4] Clean up Oink benchmarks and SM103 docs
---
oink/README.md | 136 +-
oink/benchmarks/README.md | 155 +-
oink/benchmarks/benchmark/bench_utils.py | 26 +-
.../benchmark_fused_add_rmsnorm_sm100.py | 13 +-
.../benchmark/benchmark_layernorm_sm100.py | 8 +
.../benchmark/benchmark_paulius_rmsnorm.py | 242 --
.../benchmark/benchmark_rmsnorm_all.py | 325 --
.../benchmark/benchmark_rmsnorm_bwd_sm100.py | 22 +-
.../benchmark/benchmark_rmsnorm_sm100.py | 8 +
...m103_bf16_oink_vs_quack_with_layernorm.svg | 2627 +++++++++++++++++
.../benchmarks/readme/plot_quack_style_svg.py | 32 +-
oink/benchmarks/readme/run_sm100_suite.py | 117 +-
oink/benchmarks/readme/summarize_results.py | 2 +-
oink/pyproject.toml | 6 +-
oink/src/kernelagent_oink/__init__.py | 4 +-
.../blackwell/_cutedsl_cache.py | 49 +
.../blackwell/_rmsnorm_impl.py | 46 +-
.../blackwell/_rmsnorm_simple_weightonly.py | 32 +-
.../blackwell/cross_entropy.py | 20 +-
.../kernelagent_oink/blackwell/layernorm.py | 21 +-
.../blackwell/oink_custom_ops.py | 12 +-
.../src/kernelagent_oink/blackwell/softmax.py | 21 +-
22 files changed, 3111 insertions(+), 813 deletions(-)
delete mode 100644 oink/benchmarks/benchmark/benchmark_paulius_rmsnorm.py
delete mode 100644 oink/benchmarks/benchmark/benchmark_rmsnorm_all.py
create mode 100644 oink/benchmarks/media/sm103_bf16_oink_vs_quack_with_layernorm.svg
create mode 100644 oink/src/kernelagent_oink/blackwell/_cutedsl_cache.py
diff --git a/oink/README.md b/oink/README.md
index d21720a3..bcb10328 100644
--- a/oink/README.md
+++ b/oink/README.md
@@ -1,65 +1,61 @@
# KernelAgent-Oink
-KernelAgent-Oink is a small **CuTeDSL (CUTLASS DSL) kernel library** for
-**NVIDIA Blackwell (SM10x / GB200 / GB300 / B200-class)**, bundled as a lightweight
-Python package that can be used standalone or as a **vLLM general plugin**.
+KernelAgent-Oink is a lightweight **CuTeDSL (CUTLASS DSL) kernel package** for
+NVIDIA Blackwell **SM10x** GPUs. It can be used standalone or loaded as a
+**vLLM general plugin**.
-At the moment, the vLLM integration exposes the following `torch.library.custom_op`
-entrypoints under the `oink::` namespace:
+Current custom ops:
- `torch.ops.oink.rmsnorm(x, weight, eps) -> Tensor`
- `torch.ops.oink.fused_add_rms_norm(x, residual, weight, eps) -> None` (in-place)
-The package also includes additional SM100 kernels used by the benchmark suite:
-LayerNorm, Softmax (fwd+bwd), and CrossEntropy (fwd+bwd).
+The repo also contains benchmark-facing Blackwell kernels for LayerNorm, Softmax,
+and CrossEntropy.
## Requirements
-- GPU: **SM10x (Blackwell)** for the fast CuTeDSL paths. On other GPUs, Oink falls back to
- reference PyTorch implementations for correctness.
-- Python dependencies:
- - `nvidia-cutlass-dsl` (CuTeDSL)
- - `cuda-python`
- - `torch` (provided by your environment / vLLM)
+- Blackwell GPU for optimized CuTeDSL paths; other GPUs use correctness-first
+ PyTorch fallbacks.
+- `nvidia-cutlass-dsl>=4.4.2`
+- `cuda-python`
+- `torch` from the surrounding environment / vLLM
Recommended env vars:
```bash
-export CUTE_DSL_ARCH=sm_100a
export PYTORCH_ALLOC_CONF=expandable_segments:True
+export CUTE_DSL_ARCH=sm_103a # GB300 / SM103
+# export CUTE_DSL_ARCH=sm_100a # GB200/B200 / SM100
```
-On **GB300 / SM103**, prefer:
-
-```bash
-export CUTE_DSL_ARCH=sm_103a
-```
-
-## Install (editable)
+## Install
From the `KernelAgent` repo root:
```bash
pip install -e ./oink
+pip install -e "./oink[bench]" # optional benchmark/plot deps
```
-For running the in-repo benchmark suite / plots:
+A reproducible GB300 benchmark environment used for the results below:
```bash
-pip install -e "./oink[bench]"
+conda create -y -n cute python=3.12
+conda run -n cute python -m pip install --upgrade pip setuptools wheel packaging ninja
+conda run -n cute python -m pip install --upgrade --index-url https://download.pytorch.org/whl/cu130 torch
+conda run -n cute python -m pip install 'nvidia-cutlass-dsl==4.4.2' cuda-python triton matplotlib
+conda run -n cute python -m pip install -e './oink[bench]'
```
## Usage
-### vLLM (general plugin)
-
-1) Enable the plugin:
+### vLLM plugin
```bash
export VLLM_USE_OINK_RMSNORM=1
```
-2) Ensure vLLM keeps `rms_norm` as a custom op when using `torch.compile` / CUDA graphs:
+When using `torch.compile` / CUDA graphs, keep vLLM RMSNorm as a custom op:
```python
from vllm import LLM
@@ -72,12 +68,7 @@ llm = LLM(
)
```
-Without `+rms_norm`, Inductor may fuse RMSNorm into larger kernels and neither
-vLLM’s CUDA RMSNorm nor Oink will run.
-
-### Direct PyTorch usage (manual op registration)
-
-For standalone use (outside vLLM), register the custom ops once:
+### Direct PyTorch
```python
import kernelagent_oink
@@ -92,73 +83,40 @@ y = torch.ops.oink.rmsnorm(x, w, 1e-6)
## Benchmarks
-### GB200 / B200 (SM100) benchmark suite
-
-The repo includes a Quack-style benchmark suite (tables + SVG plots) to compare
-Oink against Quack and to reproduce the reported speedups. The pre-generated
-plots below were measured on **GB200 / B200-class SM100** systems.
-
-In short, Oink’s edge comes from lower pointer-path launch overhead plus Blackwell-tuned shape routing for both hot small-`M` and larger RMSNorm rows.
-
-On the current B200 forward sweep, Oink holds `1.12x` / `1.06x` geomean over Quack for same-dtype weights on the Quack-suite / DSv3 sets, and `1.18x` / `1.06x` for fp32 weights, with worst output rel-L2 `1.45e-5` (Quack `2.01e-5`).
-
-- How to run + methodology: `oink/benchmarks/README.md`
-- Pre-generated plots: `oink/benchmarks/media/`
-
-
-
-
-
-
-
-
-
-### GB300 (SM103) Q/K-norm results
+Benchmark details and commands are in [`benchmarks/README.md`](benchmarks/README.md).
+Reported numbers are correctness-gated against PyTorch references before timing.
-We also benchmarked the real Llama4x-style Q/K-norm workload on **GB300
-(SM103)** using non-contiguous `q` / `k` views produced by `qkv.split()`. This
-benchmark reports both the direct CuTeDSL/CUTLASS baseline and the optimized
-Oink path for the production strided `[M, N]` views. The CuTeDSL/CUTLASS
-baseline here is a **Q/K-norm adaptation** derived from the
-[CUTLASS CuTeDSL Blackwell RMSNorm example](https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/rmsnorm.py),
-not the example kernel used unchanged.
+Current GB300 / SM103 setup:
-For roofline context, we also plot the same workload using a dedicated
-useful-bandwidth harness: median CUDA-event timing plus a logical IO model of
-one read + one write of the fused `[M, N]` tensor. This is the physically
-meaningful view for comparing against the measured practical GB300 BF16 stream
-roof, whereas the steady-state CUDA-graph replay medians below are better read
-as a latency view.
+- NVIDIA GB300, capability `(10, 3)`, `CUTE_DSL_ARCH=sm_103a`
+- `torch==2.11.0+cu130`, CUDA `13.0`
+- `nvidia-cutlass-dsl==4.4.2`, `cuda-python==13.2.0`
+- measured BF16 STREAM-like roof: **7.140 TB/s**
-
+
-Representative steady-state CUDA-graph replay medians from one GB300 run are
-shown below (absolute microseconds may vary slightly run to run, but the
-ranking and trend were stable).
+Quack-suite BF16 summary (`N=4096`):
-- Q path: Oink is roughly **2.4–3.1x faster** than the CuTeDSL baseline on
- representative multi-row workloads.
-- K path: Oink is roughly **2.0–3.6x faster** on the same sweep.
+| op | shapes | geomean vs Quack | large-row roofline note |
+|---|---:|---:|---|
+| RMSNorm fwd, weight=same | 19 | 1.019x | near measured roof on large rows |
+| RMSNorm fwd, weight=fp32 | 19 | 1.100x | near measured roof on large rows |
+| LayerNorm fwd | 19 | 1.241x | near measured roof on large rows |
+| Softmax fwd+bwd | 19 | 1.673x | near measured roof on large rows |
+| CrossEntropy fwd+bwd | 19 | 1.635x | mixed memory/SFU behavior |
-Takeaways from the GB300 Q/K-norm sweep:
+Historical plots remain under `benchmarks/media/`:
-- For the user-relevant multi-row workloads, Oink beats the CuTeDSL/CUTLASS
- baseline by comfortably more than 20%.
-- In the roofline view, Oink gets close to the practical GB300 BF16 streaming
- ceiling on the large-row Q/K shapes, while the CuTeDSL baseline stays much
- farther from the roof.
-- The only cases below 20% are the tiny single-row latency-floor microcases:
- Q `M=1` is ~12% faster and K `M=1` is ~6% faster.
-- Correctness spot-check from the same harness:
- - Q max diff vs eager: `0.03125`
- - K max diff vs eager: `0.007812`
+- `sm100_*`: historical SM100 / B200 runs.
+- `gb300_bf16_qk_norm_oink_vs_cutedsl_roofline.svg`: historical GB300 Q/K-norm
+ harness, separate from the Quack-suite table above.
## Links
| What | Link |
|---|---|
-| Quack (expert baseline) | https://github.com/Dao-AILab/quack |
-| KernelAgent (agentic framework) | https://github.com/meta-pytorch/KernelAgent |
-| vLLM PR (Oink RMSNorm integration) | https://github.com/vllm-project/vllm/pull/31828 |
+| Quack baseline | https://github.com/Dao-AILab/quack |
+| KernelAgent | https://github.com/meta-pytorch/KernelAgent |
+| vLLM Oink RMSNorm PR | https://github.com/vllm-project/vllm/pull/31828 |
diff --git a/oink/benchmarks/README.md b/oink/benchmarks/README.md
index 26f3c07a..c99966e1 100644
--- a/oink/benchmarks/README.md
+++ b/oink/benchmarks/README.md
@@ -1,30 +1,39 @@
-# SM100 Benchmarks (KernelAgent-Oink vs Quack)
+# Blackwell SM10x Benchmarks (KernelAgent-Oink vs Quack)
-This folder contains SM10x (GB200 / GB300 / Blackwell) microbenchmarks for the Oink
-CuTeDSL kernels vendored into KernelAgent, comparing against Quack’s SM100
-kernels where Quack provides an equivalent API.
+This folder contains SM10x (GB200 / GB300 / Blackwell) microbenchmarks for the
+Oink CuTeDSL kernels, comparing against Quack’s SM100 kernels where Quack
+provides an equivalent API.
## Prereqs
- GPU: **SM10x / Blackwell** (`torch.cuda.get_device_capability()[0] == 10`).
- Python deps in your environment:
- `torch`
- - `nvidia-cutlass-dsl` (CuTeDSL)
+ - `nvidia-cutlass-dsl>=4.4.2` (CuTeDSL)
- `cuda-python`
- `triton` (only for `triton.testing.do_bench`)
- - `quack` (optional; only needed for Oink-vs-Quack comparisons)
+ - `quack` / `quack-kernels` (optional; only needed for Oink-vs-Quack comparisons)
Recommended env vars:
```bash
export PYTORCH_ALLOC_CONF=expandable_segments:True
-export CUTE_DSL_ARCH=sm_100a
+# GB300 / SM103:
+export CUTE_DSL_ARCH=sm_103a
+# GB200/B200 / SM100 historical runs:
+# export CUTE_DSL_ARCH=sm_100a
```
-On **GB300 / SM103**, prefer:
+For the pinned GB300 / SM103 benchmark environment used to produce the current
+README numbers:
```bash
-export CUTE_DSL_ARCH=sm_103a
+conda create -y -n cute python=3.12
+conda run -n cute python -m pip install --upgrade pip setuptools wheel packaging ninja
+conda run -n cute python -m pip install --upgrade --index-url https://download.pytorch.org/whl/cu130 torch
+conda run -n cute python -m pip install 'nvidia-cutlass-dsl==4.4.2' cuda-python triton matplotlib pytest pytest-cov
+conda run -n cute python -m pip install -e '.[bench]'
+conda run -n cute python -m pip install 'git+https://github.com/Dao-AILab/quack.git' # optional comparison baseline
```
## Shape suites
@@ -34,21 +43,47 @@ export CUTE_DSL_ARCH=sm_103a
- **DeepSeek-V3-like (DSv3)**
- RMSNorm / LayerNorm / Softmax: `M ∈ {4096, 16384, 65536}`, `N ∈ {6144, 7168, 8192}`
- Cross-entropy: `M ∈ {4096, 16384, 65536}`, `N ∈ {3072, 6144, 8192, 12288}`
+- **DeepSeek-V4-Flash norm shapes (DSv4)** from `deepseek-ai/DeepSeek-V4-Flash/inference/model.py`
+ - hidden-state RMSNorm / LayerNorm: `M ∈ {4096, 16384, 65536}`, `N = 7168`
+ - q_lora RMSNorm: `M ∈ {4096, 16384, 65536}`, `N = 1536`
+ - kv latent / per-head RMSNorm: `M ∈ {4096, 16384, 65536}`, `N = 512`
## Correctness gates
-By default, each script runs a per-shape `torch.testing.assert_close` check
-vs a **pure-PyTorch reference** **before** emitting timing numbers. When Quack
-is available for that op/path, the script also validates Quack vs the *same*
+By default, each script runs a per-shape `torch.testing.assert_close` check vs a
+**pure-PyTorch reference** **before** emitting timing numbers. When Quack is
+available for that op/path, the script also validates Quack vs the *same*
reference (so speedups can’t come from looser numerics).
-Disable with `--skip-verify` only for quick smoke tests.
+Disable with `--skip-verify` only for quick smoke tests. Do not use
+`--skip-verify` for README or release performance numbers.
+
+## Roofline reporting
+
+Most benchmark JSONs include `*_hbm_frac` using `bench_utils.detect_hbm_peak_gbps()`.
+That helper is a coarse fallback (`8000 GB/s` for SM10x) so old JSONs can be
+compared consistently. For GB300/SM103 published results, use a measured roofline
+run instead.
+
+Current measured GB300 BF16 STREAM-like roof used in the README:
+
+- **7.140 TB/s** (triad, `BLOCK=2048`, `warps=8`)
+- 90% target: **6.426 TB/s**
+
+Regenerate on the current machine:
+
+```bash
+conda run -n cute bash -lc 'PYTHONNOUSERSITE=1 CUTE_DSL_ARCH=sm_103a \
+ python benchmarks/benchmark/benchmark_hbm_roofline_sm100.py --dtype bf16 --op both --gb 1 \
+ --json /tmp/oink_sm103_hbm_roofline_bf16_current.json'
+```
## Running benchmarks
-All scripts support:
+All primary scripts support:
-- `--quack-suite` or `--dsv3` (or `--configs MxN,...`)
+- `--quack-suite` or `--dsv3` (and `--dsv4` where applicable)
+- `--configs MxN,...`
- `--dtype {bf16,fp16,fp32}`
- `--iters ` and `--warmup-ms ` for kernel-only timing
- `--json ` and/or `--csv ` outputs (meta + rows)
@@ -59,100 +94,126 @@ Run the full Quack-suite + DSv3 set (Oink vs Quack) and write all JSON artifacts
to a timestamped directory:
```bash
-python oink/benchmarks/readme/run_sm100_suite.py --dtype bf16
+conda run -n cute bash -lc 'PYTHONNOUSERSITE=1 CUTE_DSL_ARCH=sm_103a \
+ python benchmarks/readme/run_sm100_suite.py --dtype bf16'
+
+# Include DeepSeek-V4-Flash norm workloads:
+conda run -n cute bash -lc 'PYTHONNOUSERSITE=1 CUTE_DSL_ARCH=sm_103a \
+ python benchmarks/readme/run_sm100_suite.py --dtype bf16 --include-dsv4 \
+ --out-dir /tmp/oink_sm103_suite_bf16_current'
```
-Turn the JSON artifacts into Markdown tables (with geomean speedups):
+Turn JSON artifacts into Markdown tables (with geomean speedups):
```bash
-python oink/benchmarks/readme/summarize_results.py --in-dir /tmp/kernelagent_oink_sm100_suite_ \
- --out /tmp/kernelagent_oink_sm100_suite_summary.md
+conda run -n cute bash -lc 'python benchmarks/readme/summarize_results.py \
+ --in-dir /tmp/oink_sm103_suite_bf16_current \
+ --out /tmp/oink_sm103_suite_bf16_current_summary.md'
```
-### Measured HBM roofline (STREAM-like)
-
-To contextualize the `*_tbps` numbers as a fraction of a *measured* bandwidth
-ceiling (rather than a theoretical spec), run:
+Generate SM103 SVGs from current JSONs and measured roofline:
```bash
-CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py --dtype bf16 --op both --gb 2 \
- --json /tmp/hbm_roofline_sm100_bf16.json
+conda run -n cute bash -lc 'python benchmarks/readme/plot_quack_style_svg.py \
+ --in-dir /tmp/oink_sm103_suite_bf16_current \
+ --suite quack_suite --include-layernorm \
+ --roofline-json /tmp/oink_sm103_hbm_roofline_bf16_current.json \
+ --arch-label "SM103 / GB300" \
+ --out benchmarks/media/sm103_bf16_oink_vs_quack_with_layernorm.svg'
+
+conda run -n cute bash -lc 'python benchmarks/readme/plot_quack_style_svg.py \
+ --in-dir /tmp/oink_sm103_suite_bf16_current \
+ --suite dsv3_all --shape-policy first \
+ --roofline-json /tmp/oink_sm103_hbm_roofline_bf16_current.json \
+ --arch-label "SM103 / GB300" \
+ --out benchmarks/media/sm103_bf16_oink_vs_quack_dsv3_all.svg'
```
+The existing `sm100_*` SVGs in `benchmarks/media/` are historical SM100/B200
+plots. Do not use them as GB300 evidence.
+
### RMSNorm forward
```bash
-python oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype fp32 --quack-suite --iters 200 --warmup-ms 25 \
+python benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype fp32 --quack-suite --iters 200 --warmup-ms 25 \
--json /tmp/oink_rmsnorm_fwd_quack_suite.json
-python oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype fp32 --dsv3 --iters 200 --warmup-ms 25 \
+python benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype fp32 --dsv3 --iters 200 --warmup-ms 25 \
--json /tmp/oink_rmsnorm_fwd_dsv3.json
# vLLM-style inference weights (weight dtype == activation dtype)
-python oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype same --quack-suite --iters 200 --warmup-ms 25 \
+python benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype same --quack-suite --iters 200 --warmup-ms 25 \
--json /tmp/oink_rmsnorm_fwd_quack_suite_wsame.json
+
+# DeepSeek-V4-Flash norm grid
+python benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype same --dsv4 --iters 200 --warmup-ms 25 \
+ --json /tmp/oink_rmsnorm_fwd_dsv4_wsame.json
```
### Fused Add + RMSNorm (vLLM-style, in-place)
-This is a good "roofline case study" kernel (heavy read/write traffic, very little extra math):
+This is a good roofline case study kernel (heavy read/write traffic, very little
+extra math):
```bash
-CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py --dtype bf16 --M 65536 --N 4096 \
+CUDA_VISIBLE_DEVICES=0 python benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py --dtype bf16 --M 65536 --N 4096 \
--json /tmp/fused_add_rmsnorm_sm100_bf16.json
```
-Note on the Quack baseline: Oink exposes an **in-place** fused op (updates `x` and `residual`).
-Quack’s fused kernel produces `out` and `residual_out` out-of-place, so by default the benchmark
-times `quack::_rmsnorm_fwd` **plus** two explicit copies (`x.copy_(out)`, `residual.copy_(residual_out)`)
-to match the in-place semantics (integration-realistic). Use `--quack-baseline kernel` to time only
-the Quack fused kernel with preallocated outputs.
+Note on the Quack baseline: Oink exposes an **in-place** fused op (updates `x`
+and `residual`). Quack’s fused kernel produces `out` and `residual_out`
+out-of-place, so by default the benchmark times `quack::_rmsnorm_fwd` **plus**
+two explicit copies (`x.copy_(out)`, `residual.copy_(residual_out)`) to match the
+in-place semantics. Use `--quack-baseline kernel` to time only the Quack fused
+kernel with preallocated outputs.
### RMSNorm backward
```bash
-python oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py --dtype bf16 --weight-dtype fp32 --quack-suite --iters 100 --warmup-ms 25 \
+python benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py --dtype bf16 --weight-dtype fp32 --quack-suite --iters 100 --warmup-ms 25 \
--csv /tmp/oink_rmsnorm_bwd_quack_suite.csv
-python oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py --dtype bf16 --weight-dtype fp32 --dsv3 --iters 100 --warmup-ms 25 \
+python benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py --dtype bf16 --weight-dtype fp32 --dsv3 --iters 100 --warmup-ms 25 \
--csv /tmp/oink_rmsnorm_bwd_dsv3.csv
```
### Softmax (forward + backward)
```bash
-python oink/benchmarks/benchmark/benchmark_softmax_sm100.py --dtype bf16 --mode fwd_bwd --quack-suite --iters 50 --warmup-ms 25 \
+python benchmarks/benchmark/benchmark_softmax_sm100.py --dtype bf16 --mode fwd_bwd --quack-suite --iters 50 --warmup-ms 25 \
--json /tmp/oink_softmax_fwd_bwd_quack_suite.json
-python oink/benchmarks/benchmark/benchmark_softmax_sm100.py --dtype bf16 --mode fwd_bwd --dsv3 --iters 50 --warmup-ms 25 \
+python benchmarks/benchmark/benchmark_softmax_sm100.py --dtype bf16 --mode fwd_bwd --dsv3 --iters 50 --warmup-ms 25 \
--json /tmp/oink_softmax_fwd_bwd_dsv3.json
```
### Cross-entropy (forward + backward)
```bash
-python oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py --dtype bf16 --mode fwd_bwd --quack-suite --iters 50 --warmup-ms 25 \
+python benchmarks/benchmark/benchmark_cross_entropy_sm100.py --dtype bf16 --mode fwd_bwd --quack-suite --iters 50 --warmup-ms 25 \
--json /tmp/oink_cross_entropy_fwd_bwd_quack_suite.json
-python oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py --dtype bf16 --mode fwd_bwd --dsv3 --iters 50 --warmup-ms 25 \
+python benchmarks/benchmark/benchmark_cross_entropy_sm100.py --dtype bf16 --mode fwd_bwd --dsv3 --iters 50 --warmup-ms 25 \
--json /tmp/oink_cross_entropy_fwd_bwd_dsv3.json
```
### LayerNorm forward
```bash
-python oink/benchmarks/benchmark/benchmark_layernorm_sm100.py --dtype bf16 --quack-suite --iters 200 --warmup-ms 25 \
+python benchmarks/benchmark/benchmark_layernorm_sm100.py --dtype bf16 --quack-suite --iters 200 --warmup-ms 25 \
--json /tmp/oink_layernorm_fwd_quack_suite.json
-python oink/benchmarks/benchmark/benchmark_layernorm_sm100.py --dtype bf16 --dsv3 --iters 200 --warmup-ms 25 \
+python benchmarks/benchmark/benchmark_layernorm_sm100.py --dtype bf16 --dsv3 --iters 200 --warmup-ms 25 \
--json /tmp/oink_layernorm_fwd_dsv3.json
```
## Notes
- These scripts intentionally avoid importing any external Oink checkout so the
- results reflect the in-tree KernelAgent Oink kernels.
-- For RMSNorm, the `rmsnorm_with_stage2` implementation is a **fallback** that
- is only used when the pointer-based fast path cannot be used (e.g. when
- `weight.dtype != x.dtype`, or when layouts/alignments are incompatible). You
+ results reflect the in-tree KernelAgent-Oink kernels.
+- `src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py` is a compatibility
+ facade. The stage-2 scheduling policy lives in `_rmsnorm_impl.py`; keep the
+ facade for downstream imports.
+- For RMSNorm, the stage-2 path is a fallback used when the pointer-based fast
+ path cannot be used (for example when layouts/alignments are incompatible). You
can force it for A/B testing via `KERNELAGENT_OINK_FORCE_RMSNORM_STAGE2=1`.
diff --git a/oink/benchmarks/benchmark/bench_utils.py b/oink/benchmarks/benchmark/bench_utils.py
index dd1c2d6e..6ba7a3fa 100644
--- a/oink/benchmarks/benchmark/bench_utils.py
+++ b/oink/benchmarks/benchmark/bench_utils.py
@@ -93,7 +93,12 @@ def ensure_blackwell_arch_env(device: Optional[torch.device] = None) -> str:
def detect_hbm_peak_gbps(device: Optional[torch.device] = None) -> float:
- """Approximate HBM peak bandwidth in GB/s for roofline fractions."""
+ """Return a coarse fallback HBM peak in GB/s for benchmark JSON fields.
+
+ This helper is intentionally approximate. For published GB300/SM103
+ roofline reporting, prefer a measured roofline JSON from
+ ``benchmark_hbm_roofline_sm100.py`` and compute fractions against that run.
+ """
if device is None:
device = torch.device("cuda")
props = torch.cuda.get_device_properties(device)
@@ -144,6 +149,25 @@ def quack_suite_configs() -> List[Tuple[int, int, int]]:
return cfgs
+def dsv4_norm_configs() -> List[Tuple[int, int]]:
+ """Return DeepSeek-V4-Flash norm shapes from `inference/model.py`.
+
+ Source dimensions:
+ - hidden-state norm: N=7168
+ - q_lora norm: N=1536
+ - kv latent / per-head norm: N=512
+ """
+ Ms = [4096, 16384, 65536]
+ Ns = [7168, 1536, 512]
+ return [(m, n) for n in Ns for m in Ms]
+
+
+def dsv4_hidden_norm_configs() -> List[Tuple[int, int]]:
+ """Return DeepSeek-V4-Flash hidden-state norm shapes (N=7168)."""
+ Ms = [4096, 16384, 65536]
+ return [(m, 7168) for m in Ms]
+
+
def ensure_oink_src_on_path() -> None:
"""Make the in-repo KernelAgent Oink package importable without an editable install."""
here = os.path.dirname(os.path.abspath(__file__))
diff --git a/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py
index e9e5b22d..148c8e5d 100644
--- a/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py
@@ -58,6 +58,7 @@
collect_device_meta,
detect_hbm_peak_gbps,
do_bench_triton,
+ dsv4_hidden_norm_configs,
ensure_blackwell_arch_env,
error_stats_to_row,
ensure_oink_src_on_path,
@@ -310,6 +311,11 @@ def main() -> None:
action="store_true",
help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}",
)
+ p.add_argument(
+ "--dsv4",
+ action="store_true",
+ help="Run DSv4 hidden-state fused-add RMSNorm set: M in {4096,16384,65536}, N=7168",
+ )
p.add_argument("--warmup-ms", type=int, default=25)
p.add_argument(
"--iters", type=int, default=200, help="rep_ms for do_bench (default: 200)"
@@ -333,7 +339,12 @@ def main() -> None:
dtype = parse_dtype(args.dtype)
meta = collect_device_meta(torch.device("cuda"))
- cfgs = dsv3_configs() if bool(args.dsv3) else [(int(args.M), int(args.N))]
+ if bool(args.dsv3):
+ cfgs = dsv3_configs()
+ elif bool(args.dsv4):
+ cfgs = dsv4_hidden_norm_configs()
+ else:
+ cfgs = [(int(args.M), int(args.N))]
rows: List[Dict[str, Any]] = []
for M, N in cfgs:
print(
diff --git a/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py b/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py
index 7ad7f779..569d4c76 100644
--- a/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py
@@ -28,6 +28,7 @@
collect_device_meta,
detect_hbm_peak_gbps,
do_bench_triton,
+ dsv4_hidden_norm_configs,
ensure_blackwell_arch_env,
error_stats_to_row,
ensure_oink_src_on_path,
@@ -355,6 +356,11 @@ def main() -> None:
action="store_true",
help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}",
)
+ p.add_argument(
+ "--dsv4",
+ action="store_true",
+ help="Run DSv4 hidden-state LayerNorm set: M in {4096,16384,65536}, N=7168",
+ )
p.add_argument(
"--skip-verify",
action="store_true",
@@ -369,6 +375,8 @@ def main() -> None:
cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()]
elif args.dsv3:
cfgs = dsv3_configs()
+ elif args.dsv4:
+ cfgs = dsv4_hidden_norm_configs()
else:
cfgs = parse_configs(args.configs)
diff --git a/oink/benchmarks/benchmark/benchmark_paulius_rmsnorm.py b/oink/benchmarks/benchmark/benchmark_paulius_rmsnorm.py
deleted file mode 100644
index b2bf75cd..00000000
--- a/oink/benchmarks/benchmark/benchmark_paulius_rmsnorm.py
+++ /dev/null
@@ -1,242 +0,0 @@
-from __future__ import annotations
-
-import argparse
-import os
-import re
-import subprocess
-from pathlib import Path
-from typing import Any, Dict, List, Tuple
-
-import torch
-
-from bench_utils import collect_device_meta, detect_hbm_peak_gbps, write_csv, write_json
-
-
-def _bench_oink_smallm_noweight(M: int, N: int) -> float:
- import sys
-
- from triton.testing import do_bench_cudagraph
-
- repo_src = Path(__file__).resolve().parents[2] / "src"
- if str(repo_src) not in sys.path:
- sys.path.insert(0, str(repo_src))
- from kernelagent_oink.blackwell import _rmsnorm_impl as impl
-
- x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
- out = torch.empty_like(x)
- return float(
- do_bench_cudagraph(
- lambda: impl._rmsnorm_forward_ptr_into(
- x, None, None, None, out, None, None, 1e-6
- ),
- rep=100,
- return_mode="mean",
- )
- )
-
-
-def bytes_io_model_fwd(M: int, N: int, dtype: torch.dtype) -> int:
- elem = torch.tensor(0, dtype=dtype).element_size()
- return int(2 * M * N * elem)
-
-
-def _cuda_13_nvcc() -> Path:
- nvcc = Path("/usr/local/cuda-13.0/bin/nvcc")
- if not nvcc.is_file():
- raise FileNotFoundError(f"CUDA 13.0 nvcc not found at {nvcc}")
- return nvcc
-
-
-def _build_paulius_binary(src_dir: Path) -> Path:
- nvcc = _cuda_13_nvcc()
- out = src_dir / "r.out"
- cmd = [
- str(nvcc),
- "-arch=sm_100",
- "-Xptxas",
- "-v",
- "-O3",
- "RmsNorm.cu",
- "-I../../../",
- "-o",
- str(out),
- "-lnvidia-ml",
- ]
- env = os.environ.copy()
- env["CUDA_HOME"] = "/usr/local/cuda-13.0"
- env["PATH"] = f"/usr/local/cuda-13.0/bin:{env.get('PATH', '')}"
- subprocess.run(
- cmd,
- cwd=src_dir,
- env=env,
- check=True,
- stdout=subprocess.DEVNULL,
- stderr=subprocess.DEVNULL,
- )
- return out
-
-
-def _parse_paulius_output(text: str) -> List[Tuple[float, float]]:
- rows: List[Tuple[float, float]] = []
- pattern = re.compile(r"BF16\s+\d+:\s+([0-9.eE+-]+)\s+ms\s+([0-9.eE+-]+)\s+GB/s")
- for line in text.splitlines():
- match = pattern.search(line)
- if match is None:
- continue
- rows.append((float(match.group(1)), float(match.group(2))))
- return rows
-
-
-def _run_paulius(
- binary: Path,
- *,
- M: int,
- N: int,
- cta_dim_y: int,
- warmup_reps: int,
- timing_reps: int,
- gpu_id: int,
-) -> Tuple[float, float, Dict[str, Any]]:
- cmd = [
- str(binary),
- str(M),
- str(N),
- str(cta_dim_y),
- str(warmup_reps),
- str(timing_reps),
- str(gpu_id),
- "0",
- "5",
- "1",
- ]
- proc = subprocess.run(
- cmd,
- cwd=binary.parent,
- text=True,
- capture_output=True,
- check=True,
- )
- parsed = _parse_paulius_output(proc.stdout)
- if not parsed:
- raise RuntimeError(
- f"Failed to parse Paulius output:\n{proc.stdout}\n{proc.stderr}"
- )
- ms, gbps = min(parsed, key=lambda row: row[0])
- return ms, gbps, {"raw_stdout": proc.stdout, "raw_stderr": proc.stderr}
-
-
-def main() -> None:
- if not torch.cuda.is_available():
- raise SystemExit("CUDA not available")
-
- torch.cuda.set_device(0)
- device = torch.device("cuda")
- props = torch.cuda.get_device_properties(device)
- sm = props.major * 10 + props.minor
- print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})")
-
- p = argparse.ArgumentParser()
- p.add_argument(
- "--paulius-dir",
- type=str,
- default=os.path.expanduser("~/fbsource/fbcode/scripts/paulius/rmsnorm"),
- )
- p.add_argument("--gpu-id", type=int, default=0)
- p.add_argument("--warmup-reps", type=int, default=10)
- p.add_argument("--timing-reps", type=int, default=100)
- p.add_argument("--configs", type=str, default="4096x4096,65536x4096")
- p.add_argument("--csv", type=str, default=None)
- p.add_argument("--json", type=str, default=None)
- args = p.parse_args()
-
- src_dir = Path(args.paulius_dir)
- binary = _build_paulius_binary(src_dir)
-
- cfgs: List[Tuple[int, int]] = []
- for part in args.configs.split(","):
- m, n = part.lower().split("x")
- cfgs.append((int(m), int(n)))
-
- meta = collect_device_meta(device)
- hbm_peak = detect_hbm_peak_gbps(device)
- rows_out: List[Dict[str, Any]] = []
- for M, N in cfgs:
- if N != 4096:
- raise SystemExit("Paulius benchmark only supports N=4096")
- best_ms = float("inf")
- best_gbps = 0.0
- best_cta_dim_y = -1
- debug_runs: List[Dict[str, Any]] = []
- for cta_dim_y in (1, 2, 4, 8):
- ms, gbps, debug = _run_paulius(
- binary,
- M=M,
- N=N,
- cta_dim_y=cta_dim_y,
- warmup_reps=int(args.warmup_reps),
- timing_reps=int(args.timing_reps),
- gpu_id=int(args.gpu_id),
- )
- debug_runs.append({"cta_dim_y": cta_dim_y, "ms": ms, "gbps": gbps})
- if ms < best_ms:
- best_ms = ms
- best_gbps = gbps
- best_cta_dim_y = cta_dim_y
- row: Dict[str, Any] = {
- "M": M,
- "N": N,
- "dtype": "bf16",
- "paulius_ms": best_ms,
- "paulius_gbps": best_gbps,
- "paulius_tbps": best_gbps / 1000.0,
- "paulius_hbm_frac": best_gbps / hbm_peak,
- "best_cta_dim_y": best_cta_dim_y,
- "io_model_bytes": bytes_io_model_fwd(M, N, torch.bfloat16),
- "cta_dim_y_candidates": debug_runs,
- }
- if M == 4096:
- oink_ms = _bench_oink_smallm_noweight(M, N)
- oink_gbps = (
- bytes_io_model_fwd(M, N, torch.bfloat16) / (oink_ms * 1e-3) / 1e9
- )
- row.update(
- {
- "oink_kernel_ms": oink_ms,
- "oink_kernel_tbps": oink_gbps / 1000.0,
- "oink_speedup_vs_paulius": best_ms / oink_ms,
- }
- )
- rows_out.append(row)
-
- if args.csv is not None:
- write_csv(args.csv, rows_out)
- if args.json is not None:
- write_json(
- args.json,
- meta,
- rows_out,
- extra={
- "method": "Paulius CUDA benchmark binary",
- "warmup_reps": int(args.warmup_reps),
- "timing_reps": int(args.timing_reps),
- "paulius_dir": str(src_dir),
- },
- )
-
- print("\nSummary:")
- print(
- f"{'M':>14} {'N':>14} {'paulius_ms':>14} {'paulius_tbps':>14}"
- f" {'ctaDimY':>14} {'oink_ms':>14} {'oink/paulius':>14}"
- )
- for r in rows_out:
- oink_ms = float(r.get("oink_kernel_ms", float("nan")))
- speedup = float(r.get("oink_speedup_vs_paulius", float("nan")))
- print(
- f"{int(r['M']):>14} {int(r['N']):>14} {float(r['paulius_ms']):14.4f}"
- f" {float(r['paulius_tbps']):14.4f} {int(r['best_cta_dim_y']):>14}"
- f" {oink_ms:14.4f} {speedup:14.4f}"
- )
-
-
-if __name__ == "__main__":
- main()
diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_all.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_all.py
deleted file mode 100644
index 6d6eb3df..00000000
--- a/oink/benchmarks/benchmark/benchmark_rmsnorm_all.py
+++ /dev/null
@@ -1,325 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Benchmark aten vs quack vs oink RMSNorm: normal dispatch + CUDA graph.
-
-All calls go through ``torch.ops.aten._fused_rms_norm``.
-Quack is registered via ``torch._native`` (quack PR pattern).
-Oink is registered via ``kernelagent_oink.register_all_kernels()``.
-
-Produces four tables:
- - Forward (normal dispatch)
- - Forward + Backward (normal dispatch)
- - Forward (CUDA graph)
- - Forward + Backward (CUDA graph)
-
-Usage::
-
- python oink/benchmarks/benchmark/benchmark_rmsnorm_all.py
-"""
-
-from __future__ import annotations
-
-import json
-import os
-import subprocess
-import sys
-import tempfile
-
-os.environ.setdefault("TORCH_NATIVE_SKIP_VERSION_CHECK", "1")
-
-
-# ---------------------------------------------------------------------------
-# Worker code: runs in a subprocess per mode to avoid cross-contamination.
-# ---------------------------------------------------------------------------
-
-WORKER_CODE = r"""
-import json, os, sys
-os.environ.setdefault("TORCH_NATIVE_SKIP_VERSION_CHECK", "1")
-
-import torch
-from triton.testing import do_bench
-
-DTYPE = torch.bfloat16
-
-def bench_normal(fn, warmup=50, rep=200):
- return do_bench(fn, warmup=warmup, rep=rep, return_mode="median")
-
-def bench_cudagraph(fn, warmup=50, rep=200):
- for _ in range(warmup):
- fn()
- torch.cuda.synchronize()
- g = torch.cuda.CUDAGraph()
- with torch.cuda.graph(g):
- fn()
- torch.cuda.synchronize()
- return do_bench(lambda: g.replay(), warmup=10, rep=rep, return_mode="median")
-
-mode = sys.argv[1]
-shapes_json = sys.argv[2]
-SHAPES = json.loads(shapes_json)
-
-if mode == "oink":
- import kernelagent_oink
- kernelagent_oink.register_all_kernels(force=True)
-
-# Warm up
-for M, N in SHAPES:
- x = torch.randn(M, N, dtype=DTYPE, device="cuda")
- w = torch.randn(N, dtype=DTYPE, device="cuda")
- torch.ops.aten._fused_rms_norm(x, [N], w, 1e-5)
-torch.cuda.synchronize()
-
-results = {}
-for M, N in SHAPES:
- x = torch.randn(M, N, dtype=DTYPE, device="cuda", requires_grad=True)
- w = torch.randn(N, dtype=DTYPE, device="cuda", requires_grad=True)
- grad = torch.randn(M, N, dtype=DTYPE, device="cuda")
-
- # Forward (normal)
- def fn_fwd(x=x, w=w, N=N):
- return torch.ops.aten._fused_rms_norm(x, [N], w, 1e-5)
- fwd_ms = bench_normal(fn_fwd)
-
- # Forward + Backward (normal)
- x_ = x.detach().requires_grad_(True)
- w_ = w.detach().requires_grad_(True)
- def fn_fwdbwd(x_=x_, w_=w_, N=N, grad=grad):
- y, _ = torch.ops.aten._fused_rms_norm(x_, [N], w_, 1e-5)
- y.backward(grad)
- fwdbwd_ms = bench_normal(fn_fwdbwd)
-
- # Forward (CUDA graph)
- x_g = torch.randn(M, N, dtype=DTYPE, device="cuda")
- w_g = torch.randn(N, dtype=DTYPE, device="cuda")
- def fn_fwd_g(x=x_g, w=w_g, N=N):
- return torch.ops.aten._fused_rms_norm(x, [N], w, 1e-5)
- try:
- fwd_graph_ms = bench_cudagraph(fn_fwd_g)
- except Exception:
- fwd_graph_ms = -1.0
-
- # Forward + Backward (CUDA graph)
- x_gb = torch.randn(M, N, dtype=DTYPE, device="cuda", requires_grad=True)
- w_gb = torch.randn(N, dtype=DTYPE, device="cuda", requires_grad=True)
- grad_gb = torch.randn(M, N, dtype=DTYPE, device="cuda")
- def fn_fwdbwd_g(x=x_gb, w=w_gb, N=N, grad=grad_gb):
- y, _ = torch.ops.aten._fused_rms_norm(x, [N], w, 1e-5)
- y.backward(grad)
- try:
- fwdbwd_graph_ms = bench_cudagraph(fn_fwdbwd_g)
- except Exception:
- fwdbwd_graph_ms = -1.0
-
- results[f"{M}x{N}"] = {
- "fwd": fwd_ms,
- "fwdbwd": fwdbwd_ms,
- "fwd_graph": fwd_graph_ms,
- "fwdbwd_graph": fwdbwd_graph_ms,
- }
-
-print(json.dumps({"mode": mode, "results": results}))
-"""
-
-
-# ---------------------------------------------------------------------------
-# Main: orchestrates subprocesses and prints tables.
-# ---------------------------------------------------------------------------
-
-SHAPES = [
- [1, 4096],
- [1, 8192],
- [32, 4096],
- [32, 8192],
- [256, 4096],
- [256, 8192],
- [1024, 4096],
- [1024, 8192],
- [4096, 4096],
- [4096, 8192],
- [16384, 4096],
- [16384, 8192],
- [65536, 4096],
- [65536, 8192],
-]
-
-COL_W = { # column widths
- "shape": 14,
- "ms": 10,
- "ratio": 8,
-}
-
-
-def find_norm_dir():
- import torch
- from pathlib import Path
-
- d = Path(torch.__file__).parent / "_native" / "ops" / "norm"
- return str(d) if d.is_dir() else None
-
-
-def run_mode(mode, norm_dir, shapes):
- init_file = os.path.join(norm_dir, "__init__.py")
-
- if mode in ("aten", "oink"):
- with open(init_file, "w") as f:
- f.write("")
- elif mode == "quack":
- with open(init_file, "w") as f:
- f.write("from . import rmsnorm_impl # noqa: F401\n")
-
- with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as tmp:
- tmp.write(WORKER_CODE)
- tmp_path = tmp.name
-
- try:
- result = subprocess.run(
- [sys.executable, tmp_path, mode, json.dumps(shapes)],
- capture_output=True,
- text=True,
- timeout=600,
- )
- if result.returncode != 0:
- print(f" [{mode}] FAILED: {result.stderr[-300:]}", file=sys.stderr)
- return None
- return json.loads(result.stdout.strip())["results"]
- finally:
- os.unlink(tmp_path)
-
-
-def _fmt_ms(v):
- return f"{v:>{COL_W['ms']}.4f}" if v > 0 else "FAIL".rjust(COL_W["ms"])
-
-
-def _fmt_ratio(n, d):
- if d <= 0 or n <= 0:
- return "N/A".rjust(COL_W["ratio"])
- return f"{f'{n / d:.2f}x':>{COL_W['ratio']}}"
-
-
-def print_table(title, subtitle, aten, quack, oink, key):
- sw, mw, rw = COL_W["shape"], COL_W["ms"], COL_W["ratio"]
- w = [sw, mw, mw, mw, rw, rw, rw]
-
- def hr(left, mid, right):
- return left + mid.join("─" * (c + 2) for c in w) + right
-
- hdr = (
- f"│ {'Shape (M,N)':^{sw}} "
- f"│ {'Aten (ms)':^{mw}} "
- f"│ {'Quack (ms)':^{mw}} "
- f"│ {'Oink (ms)':^{mw}} "
- f"│ {'Q/A':^{rw}} "
- f"│ {'O/A':^{rw}} "
- f"│ {'O/Q':^{rw}} │"
- )
-
- print()
- print(f" {title}")
- print(f" {subtitle}")
- print(hr("┌", "┬", "┐"))
- print(hdr)
- print(hr("├", "┼", "┤"))
-
- for shape in aten:
- M, N = shape.split("x")
- a, q, o = aten[shape][key], quack[shape][key], oink[shape][key]
- row = (
- f"│ {f'({M},{N})':>{sw}} "
- f"│ {_fmt_ms(a)} "
- f"│ {_fmt_ms(q)} "
- f"│ {_fmt_ms(o)} "
- f"│ {_fmt_ratio(a, q)} "
- f"│ {_fmt_ratio(a, o)} "
- f"│ {_fmt_ratio(q, o)} │"
- )
- print(row)
-
- print(hr("└", "┴", "┘"))
-
-
-def main():
- import torch
-
- print("=" * 72)
- print(" RMSNorm Kernel Benchmark: Aten vs Quack vs Oink")
- print("=" * 72)
- print(f" Device : {torch.cuda.get_device_name(0)}")
- print(f" Torch : {torch.__version__}")
- print(" Dtype : bfloat16")
- print(" Quack : registered via torch._native (quack PR)")
- print(" Oink : registered via kernelagent_oink.register_all_kernels()")
- print(" Bench : triton.testing.do_bench (median, 200 reps)")
-
- norm_dir = find_norm_dir()
- if norm_dir is None:
- print("ERROR: torch._native/ops/norm/ not found.", file=sys.stderr)
- sys.exit(1)
-
- print()
- print("Running aten...")
- aten = run_mode("aten", norm_dir, SHAPES)
- print("Running quack...")
- quack = run_mode("quack", norm_dir, SHAPES)
- print("Running oink...")
- oink = run_mode("oink", norm_dir, SHAPES)
-
- # Restore
- with open(os.path.join(norm_dir, "__init__.py"), "w") as f:
- f.write("from . import rmsnorm_impl # noqa: F401\n")
-
- if not all([aten, quack, oink]):
- print("ERROR: one or more modes failed.", file=sys.stderr)
- sys.exit(1)
-
- print_table(
- "Forward — Normal Dispatch",
- "Standard Python dispatch through torch.ops.aten._fused_rms_norm.",
- aten,
- quack,
- oink,
- "fwd",
- )
- print_table(
- "Forward + Backward — Normal Dispatch",
- "Fwd + autograd backward, standard Python dispatch.",
- aten,
- quack,
- oink,
- "fwdbwd",
- )
- print_table(
- "Forward — CUDA Graph (zero Python overhead)",
- "Kernel captured once, replayed without re-entering Python.",
- aten,
- quack,
- oink,
- "fwd_graph",
- )
- print_table(
- "Forward + Backward — CUDA Graph (zero Python overhead)",
- "Fwd + bwd captured once, replayed without re-entering Python.",
- aten,
- quack,
- oink,
- "fwdbwd_graph",
- )
-
- print()
- print("Done.")
-
-
-if __name__ == "__main__":
- main()
diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py
index e137de7c..10385972 100644
--- a/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py
@@ -18,7 +18,7 @@
import csv
import os
from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import List, Tuple
import torch
from triton.testing import do_bench as triton_do_bench
@@ -29,6 +29,8 @@
from bench_utils import ( # noqa: E402
ErrorStatsAccumulator,
collect_device_meta,
+ detect_hbm_peak_gbps,
+ dsv4_norm_configs,
ensure_blackwell_arch_env,
ensure_oink_src_on_path,
error_stats_to_row,
@@ -55,17 +57,6 @@
}
-def detect_hbm_peak_gbps(device: Optional[torch.device] = None) -> float:
- """Approximate HBM peak bandwidth in GB/s for roofline fractions."""
- if device is None:
- device = torch.device("cuda")
- props = torch.cuda.get_device_properties(device)
- sm = props.major * 10 + props.minor
- if sm >= 100:
- return 8000.0
- return 2000.0
-
-
@dataclass
class Result:
ms: float
@@ -360,6 +351,11 @@ def main() -> None:
action="store_true",
help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}",
)
+ p.add_argument(
+ "--dsv4",
+ action="store_true",
+ help="Run DSv4 norm set: M in {4096,16384,65536}, N in {7168,1536,512}",
+ )
p.add_argument(
"--skip-verify",
action="store_true",
@@ -378,6 +374,8 @@ def main() -> None:
cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()]
elif args.dsv3:
cfgs = dsv3_configs()
+ elif args.dsv4:
+ cfgs = dsv4_norm_configs()
else:
cfgs = parse_configs(args.configs)
diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py
index 809a4756..fccafedf 100644
--- a/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py
@@ -28,6 +28,7 @@
collect_device_meta,
detect_hbm_peak_gbps,
do_bench_triton,
+ dsv4_norm_configs,
ensure_blackwell_arch_env,
error_stats_to_row,
ensure_oink_src_on_path,
@@ -285,6 +286,11 @@ def main() -> None:
action="store_true",
help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}",
)
+ p.add_argument(
+ "--dsv4",
+ action="store_true",
+ help="Run DSv4 norm set: M in {4096,16384,65536}, N in {7168,1536,512}",
+ )
p.add_argument(
"--skip-verify",
action="store_true",
@@ -303,6 +309,8 @@ def main() -> None:
cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()]
elif args.dsv3:
cfgs = dsv3_configs()
+ elif args.dsv4:
+ cfgs = dsv4_norm_configs()
else:
cfgs = parse_configs(args.configs)
diff --git a/oink/benchmarks/media/sm103_bf16_oink_vs_quack_with_layernorm.svg b/oink/benchmarks/media/sm103_bf16_oink_vs_quack_with_layernorm.svg
new file mode 100644
index 00000000..0e161033
--- /dev/null
+++ b/oink/benchmarks/media/sm103_bf16_oink_vs_quack_with_layernorm.svg
@@ -0,0 +1,2627 @@
+
+
+
diff --git a/oink/benchmarks/readme/plot_quack_style_svg.py b/oink/benchmarks/readme/plot_quack_style_svg.py
index 88eebdf3..5e0224a8 100644
--- a/oink/benchmarks/readme/plot_quack_style_svg.py
+++ b/oink/benchmarks/readme/plot_quack_style_svg.py
@@ -13,8 +13,8 @@
# limitations under the License.
"""
-Generate Quack-style SVG performance plots (Oink vs Quack) from the SM100 suite
-JSON artifacts under `/tmp/kernelagent_oink_sm100_suite_{bf16,fp16}`.
+Generate Quack-style SVG performance plots (Oink vs Quack) from SM10x suite
+JSON artifacts under a suite output directory.
The intent is to match Quack's README visual style:
- 3 horizontal panels (suite-dependent):
@@ -240,12 +240,13 @@ def _plot(
label="Quack",
)
if roofline_gbps is not None:
+ roof_label = f"HBM peak (measured {roofline_gbps / 1000.0:.3f} TB/s)"
ax.axhline(
roofline_gbps,
color=COLOR_ROOF,
linewidth=3,
linestyle=(0, (4, 6)),
- label="HBM peak (measured)" if ax is axes[0] else None,
+ label=roof_label if ax is axes[0] else None,
)
max_y = max(max_y, float(roofline_gbps))
@@ -389,7 +390,19 @@ def main() -> None:
"--roofline-json",
type=str,
default=None,
- help="Optional /tmp/hbm_roofline_sm100_*.json path",
+ help="Optional measured roofline JSON path from benchmark_hbm_roofline_sm100.py",
+ )
+ p.add_argument(
+ "--roofline-gbps",
+ type=float,
+ default=None,
+ help="Optional measured roofline in GB/s (mutually exclusive with --roofline-json).",
+ )
+ p.add_argument(
+ "--arch-label",
+ type=str,
+ default="SM100",
+ help="Architecture label used in auto-generated titles, e.g. 'SM103 / GB300'.",
)
p.add_argument("--out", type=str, required=True, help="Output SVG path")
p.add_argument(
@@ -401,8 +414,12 @@ def main() -> None:
if not os.path.isdir(in_dir):
raise SystemExit(f"--in-dir is not a directory: {in_dir}")
+ if args.roofline_json is not None and args.roofline_gbps is not None:
+ raise SystemExit("Use only one of --roofline-json or --roofline-gbps.")
roofline_gbps = (
- _read_roofline_gbps(args.roofline_json) if args.roofline_json else None
+ float(args.roofline_gbps)
+ if args.roofline_gbps is not None
+ else (_read_roofline_gbps(args.roofline_json) if args.roofline_json else None)
)
panel_files = list(_panel_files_for_suite(str(args.suite)))
@@ -451,10 +468,11 @@ def main() -> None:
if (args.suite == "quack_suite" and args.include_layernorm)
else ""
)
+ arch_label = str(args.arch_label)
if args.suite == "dsv3_cross_entropy":
- title = f"SM100 {dtype.upper()} — {suite_name}{suffix}"
+ title = f"{arch_label} {dtype.upper()} — {suite_name}{suffix}"
else:
- title = f"SM100 {dtype.upper()} Kernel Benchmarks (Oink vs Quack) — {suite_name}{suffix}"
+ title = f"{arch_label} {dtype.upper()} Kernel Benchmarks (Oink vs Quack) — {suite_name}{suffix}"
_plot(
panels=panels,
diff --git a/oink/benchmarks/readme/run_sm100_suite.py b/oink/benchmarks/readme/run_sm100_suite.py
index 05bd2116..71920e38 100644
--- a/oink/benchmarks/readme/run_sm100_suite.py
+++ b/oink/benchmarks/readme/run_sm100_suite.py
@@ -56,6 +56,11 @@ def main() -> None:
action="store_true",
help="Skip correctness checks (Oink/Quack vs PyTorch / pure-PyTorch references)",
)
+ p.add_argument(
+ "--include-dsv4",
+ action="store_true",
+ help="Also run DeepSeek-V4-Flash norm workloads (RMSNorm N={7168,1536,512}; LayerNorm/fused-add N=7168).",
+ )
p.add_argument(
"--dry-run", action="store_true", help="Print commands without executing them"
)
@@ -83,7 +88,115 @@ def script(name: str) -> str:
if args.skip_verify:
common = [*common, "--skip-verify"]
- runs: List[Tuple[str, List[str]]] = [
+ runs: List[Tuple[str, List[str]]] = []
+
+ if args.include_dsv4:
+ runs.extend(
+ [
+ (
+ "rmsnorm_fwd_dsv4_wfp32",
+ [
+ py,
+ script("benchmark_rmsnorm_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "fp32",
+ "--dsv4",
+ "--iters",
+ "200",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_fwd_dsv4_wfp32.json"),
+ ],
+ ),
+ (
+ "rmsnorm_fwd_dsv4_wsame",
+ [
+ py,
+ script("benchmark_rmsnorm_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "same",
+ "--dsv4",
+ "--iters",
+ "200",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_fwd_dsv4_wsame.json"),
+ ],
+ ),
+ (
+ "rmsnorm_bwd_dsv4_wfp32",
+ [
+ py,
+ script("benchmark_rmsnorm_bwd_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "fp32",
+ "--dsv4",
+ "--iters",
+ "100",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_bwd_dsv4_wfp32.json"),
+ ],
+ ),
+ (
+ "rmsnorm_bwd_dsv4_wsame",
+ [
+ py,
+ script("benchmark_rmsnorm_bwd_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "same",
+ "--dsv4",
+ "--iters",
+ "100",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_bwd_dsv4_wsame.json"),
+ ],
+ ),
+ (
+ "fused_add_rmsnorm_dsv4",
+ [
+ py,
+ script("benchmark_fused_add_rmsnorm_sm100.py"),
+ *common,
+ "--dsv4",
+ "--quack-baseline",
+ "kernel_inplace",
+ "--iters",
+ "200",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "fused_add_rmsnorm_dsv4.json"),
+ ],
+ ),
+ (
+ "layernorm_fwd_dsv4",
+ [
+ py,
+ script("benchmark_layernorm_sm100.py"),
+ *common,
+ "--dsv4",
+ "--iters",
+ "200",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "layernorm_fwd_dsv4.json"),
+ ],
+ ),
+ ]
+ )
+
+ runs.extend([
(
"rmsnorm_fwd_quack_suite_wfp32",
[
@@ -336,7 +449,7 @@ def script(name: str) -> str:
os.path.join(out_dir, "layernorm_fwd_dsv3.json"),
],
),
- ]
+ ])
print(f"Writing results to: {out_dir}", flush=True)
for name, cmd in runs:
diff --git a/oink/benchmarks/readme/summarize_results.py b/oink/benchmarks/readme/summarize_results.py
index 684694d6..efc92e4e 100644
--- a/oink/benchmarks/readme/summarize_results.py
+++ b/oink/benchmarks/readme/summarize_results.py
@@ -232,7 +232,7 @@ def main() -> None:
raise SystemExit(f"No .json files found under: {in_dir}")
out_parts: List[str] = []
- out_parts.append("# KernelAgent-Oink SM100 Benchmark Summary")
+ out_parts.append("# KernelAgent-Oink SM10x Benchmark Summary")
out_parts.append("")
out_parts.append(f"Input directory: `{in_dir}`")
out_parts.append("")
diff --git a/oink/pyproject.toml b/oink/pyproject.toml
index e0f19270..9dedd92f 100644
--- a/oink/pyproject.toml
+++ b/oink/pyproject.toml
@@ -5,12 +5,12 @@ build-backend = "setuptools.build_meta"
[project]
name = "kernelagent-oink"
version = "0.1.0"
-description = "CuTeDSL kernels for Blackwell (SM100), shipped as a vLLM plugin"
+description = "CuTeDSL kernels for Blackwell SM10x (SM100-SM103), shipped as a vLLM plugin"
readme = "README.md"
requires-python = ">=3.10"
license = {text = "Apache-2.0"}
authors = [{name = "PyTorch Labs"}]
-keywords = ["cuda", "cutlass", "cute", "cutedsl", "blackwell", "sm100", "vllm"]
+keywords = ["cuda", "cutlass", "cute", "cutedsl", "blackwell", "sm100", "sm103", "gb300", "vllm"]
classifiers = [
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
@@ -27,7 +27,7 @@ classifiers = [
# We intentionally do NOT depend on `torch` here because vLLM already pins and
# provides a compatible PyTorch build.
dependencies = [
- "nvidia-cutlass-dsl>=4.2.1",
+ "nvidia-cutlass-dsl>=4.4.2",
"cuda-python",
]
diff --git a/oink/src/kernelagent_oink/__init__.py b/oink/src/kernelagent_oink/__init__.py
index d61b60e4..4fe33b1c 100644
--- a/oink/src/kernelagent_oink/__init__.py
+++ b/oink/src/kernelagent_oink/__init__.py
@@ -13,7 +13,7 @@
# limitations under the License.
"""
-KernelAgent-Oink: SM100 CuTeDSL kernels + optional vLLM plugin.
+KernelAgent-Oink: Blackwell SM10x CuTeDSL kernels + optional vLLM plugin.
This package can be loaded as a vLLM "general plugin" (entrypoint group
`vllm.general_plugins`). In that mode it registers Oink custom ops only when
@@ -135,7 +135,7 @@ def register(*, force: bool = False) -> None:
def register_all_kernels(*, force: bool = False) -> None:
"""Override aten ops with Oink's kernels.
- Checks CUDA/SM100/deps, sets up the CuTeDSL environment, then overrides
+ Checks CUDA/Blackwell SM10x/deps, sets up the CuTeDSL environment, then overrides
``aten::_fused_rms_norm`` and ``aten::_fused_rms_norm_backward`` on CUDA.
Does NOT register ``torch.ops.oink.*`` custom ops — use :func:`register`
diff --git a/oink/src/kernelagent_oink/blackwell/_cutedsl_cache.py b/oink/src/kernelagent_oink/blackwell/_cutedsl_cache.py
new file mode 100644
index 00000000..2e2536d7
--- /dev/null
+++ b/oink/src/kernelagent_oink/blackwell/_cutedsl_cache.py
@@ -0,0 +1,49 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""CuTeDSL cache setup shared by Blackwell kernel modules.
+
+CuTeDSL cache bytecode is version-sensitive. The default global cache
+(``/tmp/$USER/cutlass_python_cache``) can be shared across environments with
+incompatible ``nvidia-cutlass-dsl`` versions, producing noisy warnings and
+lost cache reuse. Call this helper before importing ``cutlass`` in modules
+that compile CuTeDSL kernels.
+"""
+
+from __future__ import annotations
+
+import importlib.metadata
+import os
+import re
+
+
+def ensure_versioned_cutedsl_cache_dir() -> None:
+ """Set a version-scoped CuTeDSL cache directory when the user did not.
+
+ The path format intentionally matches the historical per-module logic:
+    ``$TMPDIR/$USER/cutlass_python_cache_<nvidia-cutlass-dsl version>``.
+ If ``CUTE_DSL_CACHE_DIR`` is already set, leave it untouched.
+ """
+ if "CUTE_DSL_CACHE_DIR" in os.environ:
+ return
+ try:
+ dsl_ver = importlib.metadata.version("nvidia-cutlass-dsl")
+ except Exception:
+ dsl_ver = "unknown"
+ dsl_ver = re.sub(r"[^0-9A-Za-z]+", "_", dsl_ver)
+ user = os.environ.get("USER") or os.environ.get("USERNAME") or "user"
+ tmp = os.environ.get("TMPDIR") or "/tmp"
+ os.environ["CUTE_DSL_CACHE_DIR"] = os.path.join(
+ tmp, user, f"cutlass_python_cache_{dsl_ver}"
+ )
diff --git a/oink/src/kernelagent_oink/blackwell/_rmsnorm_impl.py b/oink/src/kernelagent_oink/blackwell/_rmsnorm_impl.py
index cbea22d8..3af5ea3b 100644
--- a/oink/src/kernelagent_oink/blackwell/_rmsnorm_impl.py
+++ b/oink/src/kernelagent_oink/blackwell/_rmsnorm_impl.py
@@ -16,25 +16,14 @@
from __future__ import annotations
-import importlib.metadata
import os
-import re
from dataclasses import dataclass, replace
# Vendored/adapted from Quack's SM100 RMSNorm with Oink-specific B200 tuning.
-# CuTeDSL cache bytecode is version-sensitive, so isolate the default cache.
-if "CUTE_DSL_CACHE_DIR" not in os.environ:
- try:
- _dsl_ver = importlib.metadata.version("nvidia-cutlass-dsl")
- except Exception:
- _dsl_ver = "unknown"
- _dsl_ver = re.sub(r"[^0-9A-Za-z]+", "_", _dsl_ver)
- _user = os.environ.get("USER") or os.environ.get("USERNAME") or "user"
- _tmp = os.environ.get("TMPDIR") or "/tmp"
- os.environ["CUTE_DSL_CACHE_DIR"] = os.path.join(
- _tmp, _user, f"cutlass_python_cache_{_dsl_ver}"
- )
+from kernelagent_oink.blackwell._cutedsl_cache import ensure_versioned_cutedsl_cache_dir
+
+ensure_versioned_cutedsl_cache_dir()
try:
import cutlass # type: ignore # noqa: F401
@@ -182,7 +171,16 @@ def _resolve_forward_launch_config(
weight_dtype: type[cutlass.Numeric] | None,
aligned_tensors: tuple[Tensor | None, ...],
) -> _ForwardLaunchConfig:
- direct_gmem_default = bool(dtype.width == 16 and N in {128, 4096, 6144, 7168, 8192})
+ direct_gmem_default = bool(
+ dtype.width == 16 and N in {128, 512, 4096, 6144, 7168, 8192}
+ )
+ if (
+ dtype.width == 16
+ and N == 1536
+ and weight_dtype is not None
+ and weight_dtype.width == 16
+ ):
+ direct_gmem_default = True
if weight_dtype is not None and weight_dtype.width == 32 and N == 7168:
direct_gmem_default = False
direct_gmem = _direct_gmem_from_policy(default=direct_gmem_default)
@@ -197,7 +195,7 @@ def _resolve_forward_launch_config(
default_copy_bits = 256 if can_use_256 else 128
if dtype.width == 16 and N == 128:
default_copy_bits = 128
- if dtype.width == 16 and N == 4096:
+ if dtype.width == 16 and N in {512, 1536, 4096}:
default_copy_bits = 128
if dtype.width == 16 and weight_dtype is not None and weight_dtype.width == 32:
default_copy_bits = 128 if N == 4096 else 64
@@ -205,6 +203,12 @@ def _resolve_forward_launch_config(
copy_bits = _copy_bits_from_policy(
default=default_copy_bits, can_use_256=can_use_256
)
+ # cp.async supports at most 128 bits per instruction. The copy atom clamps
+ # async copies to 128b, so keep the TV layout's vector width in sync with the
+ # emitted copy width; otherwise shapes such as DSv4 N=1536 can leave half of
+ # each logical vector tile uninitialized.
+ if use_async and copy_bits > 128:
+ copy_bits = 128
if use_async and copy_bits < 128:
use_async = False
@@ -284,6 +288,16 @@ def _forward_launch_overrides(
nt_default: int | None = None
cluster_n_default: int | None = None
+ if (
+ dtype.width == 16
+ and weight_dtype is not None
+ and weight_dtype.width == 16
+ and N == 1536
+ and direct_gmem
+ and M >= 4096
+ ):
+ tpr_default = 32
+ nt_default = 32
if (
dtype.width == 16
and weight_dtype is not None
diff --git a/oink/src/kernelagent_oink/blackwell/_rmsnorm_simple_weightonly.py b/oink/src/kernelagent_oink/blackwell/_rmsnorm_simple_weightonly.py
index 2a790f6f..d9d10734 100644
--- a/oink/src/kernelagent_oink/blackwell/_rmsnorm_simple_weightonly.py
+++ b/oink/src/kernelagent_oink/blackwell/_rmsnorm_simple_weightonly.py
@@ -1,8 +1,22 @@
+"""Measured same-dtype bf16 RMSNorm forward specializations.
+
+This module implements a narrow fast path for row-major bf16 tensors with a
+bf16 1D weight and no residual/bias/rstd outputs. The math is
+``y = x * rsqrt(mean(x * x) + eps) * weight`` with fp32 reduction/multiply and
+bf16 output. The shape table below is deliberately measured and narrow; shapes
+not listed fall back to the generic RMSNorm pointer path.
+"""
+
+from dataclasses import dataclass
+
import cuda.bindings.driver as cuda
import torch
-from dataclasses import dataclass
from torch import Tensor
+from kernelagent_oink.blackwell._cutedsl_cache import ensure_versioned_cutedsl_cache_dir
+
+ensure_versioned_cutedsl_cache_dir()
+
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
@@ -19,6 +33,10 @@
_COMPILED_CACHE: dict[tuple[object, int, int, int], object] = {}
_SIMPLE_WEIGHTONLY_SHAPES: dict[tuple[int, int], tuple[int, int]] = {
+ # DeepSeek-V4-Flash q_lora same-dtype RMSNorm shape. Larger M and the
+ # kv/per-head N=512 cases are faster through the generic pointer path on SM103.
+ (4096, 1536): (96, 96),
+ # DeepSeek-V3 hidden-state same-dtype RMSNorm shapes.
(4096, 6144): (192, 192),
(4096, 7168): (224, 224),
(4096, 8192): (256, 256),
@@ -66,7 +84,12 @@ def cache_key(self) -> tuple[object, int, int, int]:
)
@staticmethod
- def make_tv_layout(threads_per_row, rows_per_block, vec_size, num_vec_blocks):
+ def make_tv_layout(
+ threads_per_row: int,
+ rows_per_block: int,
+ vec_size: int,
+ num_vec_blocks: int,
+ ):
shape = ((threads_per_row, rows_per_block), (vec_size, num_vec_blocks))
stride = (
(vec_size * rows_per_block, 1),
@@ -74,7 +97,7 @@ def make_tv_layout(threads_per_row, rows_per_block, vec_size, num_vec_blocks):
)
return shape, stride
- def smem_bytes(self):
+ def smem_bytes(self) -> int:
return (
self.rows_per_block * self.cols_per_tile * (self.dtype.width // 8)
+ self.rows_per_block * self.warps_per_row * 4
@@ -210,9 +233,6 @@ def kernel(
cute.copy(copy_atom_store, tXrO, tXgO)
-def _can_use_simple_weightonly(x: Tensor, weight: Tensor, out: Tensor) -> bool:
- return _get_simple_weightonly_config(x, weight, out) is not None
-
def _get_simple_weightonly_config(
x: Tensor,
diff --git a/oink/src/kernelagent_oink/blackwell/cross_entropy.py b/oink/src/kernelagent_oink/blackwell/cross_entropy.py
index d8b37ea2..06f72084 100644
--- a/oink/src/kernelagent_oink/blackwell/cross_entropy.py
+++ b/oink/src/kernelagent_oink/blackwell/cross_entropy.py
@@ -39,10 +39,8 @@
from __future__ import annotations
-import importlib.metadata
import math
import os
-import re
from typing import Literal, Optional, Type
import torch
@@ -50,21 +48,9 @@
import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python
-# CuTeDSL caches generated MLIR into a tempdir under a global default
-# (`/tmp/$USER/cutlass_python_cache`). The cache bytecode format can differ across
-# `nvidia-cutlass-dsl` versions, and cross-version cache sharing causes noisy
-# warnings (and disables cache reuse).
-if "CUTE_DSL_CACHE_DIR" not in os.environ:
- try:
- _dsl_ver = importlib.metadata.version("nvidia-cutlass-dsl")
- except Exception:
- _dsl_ver = "unknown"
- _dsl_ver = re.sub(r"[^0-9A-Za-z]+", "_", _dsl_ver)
- _user = os.environ.get("USER") or os.environ.get("USERNAME") or "user"
- _tmp = os.environ.get("TMPDIR") or "/tmp"
- os.environ["CUTE_DSL_CACHE_DIR"] = os.path.join(
- _tmp, _user, f"cutlass_python_cache_{_dsl_ver}"
- )
+from kernelagent_oink.blackwell._cutedsl_cache import ensure_versioned_cutedsl_cache_dir
+
+ensure_versioned_cutedsl_cache_dir()
try:
import cutlass # type: ignore # noqa: F401
diff --git a/oink/src/kernelagent_oink/blackwell/layernorm.py b/oink/src/kernelagent_oink/blackwell/layernorm.py
index ada51ecd..4e5190b7 100644
--- a/oink/src/kernelagent_oink/blackwell/layernorm.py
+++ b/oink/src/kernelagent_oink/blackwell/layernorm.py
@@ -30,10 +30,7 @@
from __future__ import annotations
-import importlib.metadata
import math
-import os
-import re
import operator
from typing import Optional, Tuple, Type
@@ -42,21 +39,9 @@
import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python
-# CuTeDSL caches generated MLIR into a tempdir under a global default
-# (`/tmp/$USER/cutlass_python_cache`). The cache bytecode format can differ across
-# `nvidia-cutlass-dsl` versions, and cross-version cache sharing causes noisy
-# warnings (and disables cache reuse).
-if "CUTE_DSL_CACHE_DIR" not in os.environ:
- try:
- _dsl_ver = importlib.metadata.version("nvidia-cutlass-dsl")
- except Exception:
- _dsl_ver = "unknown"
- _dsl_ver = re.sub(r"[^0-9A-Za-z]+", "_", _dsl_ver)
- _user = os.environ.get("USER") or os.environ.get("USERNAME") or "user"
- _tmp = os.environ.get("TMPDIR") or "/tmp"
- os.environ["CUTE_DSL_CACHE_DIR"] = os.path.join(
- _tmp, _user, f"cutlass_python_cache_{_dsl_ver}"
- )
+from kernelagent_oink.blackwell._cutedsl_cache import ensure_versioned_cutedsl_cache_dir
+
+ensure_versioned_cutedsl_cache_dir()
try:
import cutlass # type: ignore # noqa: F401
diff --git a/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py b/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py
index 282514d3..348b4f57 100644
--- a/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py
+++ b/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py
@@ -16,7 +16,7 @@
Torch custom ops wrapping Oink's Blackwell RMSNorm kernels.
These ops are designed to be:
-- Architecture-aware (use CuTeDSL SM100 kernels when available, fall back
+- Architecture-aware (use CuTeDSL Blackwell SM10x kernels when available, fall back
to a safe reference elsewhere).
- Layout-preserving for 2D row-major inputs, including padded MLA-style
layouts where stride(0) > N and stride(1) == 1.
@@ -69,7 +69,7 @@ def _get_rmsnorm_mod():
def _get_sm(device: torch.device | None = None) -> int:
- """Return SM version as an int (e.g., 100 for SM100 / Blackwell)."""
+ """Return SM version as an int (e.g., 103 for SM103 / Blackwell)."""
if device is None:
device = torch.device("cuda")
major, minor = torch.cuda.get_device_capability(device)
@@ -95,7 +95,7 @@ def oink_rmsnorm(
dimension stride(0) may be larger than N (padded-row layouts), and
will be preserved on the fast CuTeDSL path.
- On SM100 (and newer), this dispatches to the tuned CuTeDSL Blackwell
+ On Blackwell SM10x (SM100 and newer), this dispatches to the tuned CuTeDSL Blackwell
RMSNorm kernel in rmsnorm.rmsnorm_forward, which in turn selects the
best internal schedule (including DSv3-specific stage-2 kernels where
applicable) and preserves the input's 2D stride when using the
@@ -111,7 +111,7 @@ def oink_rmsnorm(
sm = _get_sm(x.device)
_rms = _get_rmsnorm_mod()
if sm >= 100:
- # Use the tuned CuTeDSL SM100 kernel. The public API already
+ # Use the tuned CuTeDSL Blackwell kernel. The public API already
# contains all necessary gating and layout checks internally.
y, _rstd, _res = _rms.rmsnorm_forward(
x,
@@ -186,13 +186,13 @@ def oink_fused_add_rms_norm(
_rms = _get_rmsnorm_mod()
if sm < 100:
- # Non-SM100 fallback: keep semantics in-place (correctness-first).
+ # Non-SM10x fallback: keep semantics in-place (correctness-first).
residual.add_(x)
y = _rms.rmsnorm_ref(residual, w=weight, b=None, residual=None, eps=eps)
x.copy_(y)
return None
- # SM100+: prefer the lowest-overhead in-place entrypoint (returns None).
+ # SM10x+: prefer the lowest-overhead in-place entrypoint (returns None).
if hasattr(_rms, "fused_add_rmsnorm_inplace_"):
_rms.fused_add_rmsnorm_inplace_( # type: ignore[misc]
x,
diff --git a/oink/src/kernelagent_oink/blackwell/softmax.py b/oink/src/kernelagent_oink/blackwell/softmax.py
index 394ab486..3ee93c56 100644
--- a/oink/src/kernelagent_oink/blackwell/softmax.py
+++ b/oink/src/kernelagent_oink/blackwell/softmax.py
@@ -26,9 +26,6 @@
from __future__ import annotations
-import importlib.metadata
-import os
-import re
from typing import Type
import torch
@@ -36,21 +33,9 @@
import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python
-# CuTeDSL caches generated MLIR into a tempdir under a global default
-# (`/tmp/$USER/cutlass_python_cache`). The cache bytecode format can differ across
-# `nvidia-cutlass-dsl` versions, and cross-version cache sharing causes noisy
-# warnings (and disables cache reuse).
-if "CUTE_DSL_CACHE_DIR" not in os.environ:
- try:
- _dsl_ver = importlib.metadata.version("nvidia-cutlass-dsl")
- except Exception:
- _dsl_ver = "unknown"
- _dsl_ver = re.sub(r"[^0-9A-Za-z]+", "_", _dsl_ver)
- _user = os.environ.get("USER") or os.environ.get("USERNAME") or "user"
- _tmp = os.environ.get("TMPDIR") or "/tmp"
- os.environ["CUTE_DSL_CACHE_DIR"] = os.path.join(
- _tmp, _user, f"cutlass_python_cache_{_dsl_ver}"
- )
+from kernelagent_oink.blackwell._cutedsl_cache import ensure_versioned_cutedsl_cache_dir
+
+ensure_versioned_cutedsl_cache_dir()
try:
import cutlass # type: ignore # noqa: F401
From 88b85ffe2bd125cbfe897db80e1a4c1b132246fb Mon Sep 17 00:00:00 2001
From: Laura Wang <3700467+Laurawly@users.noreply.github.com>
Date: Wed, 29 Apr 2026 14:41:33 -0700
Subject: [PATCH 2/4] Document fused RMSNorm benchmark results
---
oink/benchmarks/README.md | 56 ++++++++++++++++++++++++++++++++-------
1 file changed, 47 insertions(+), 9 deletions(-)
diff --git a/oink/benchmarks/README.md b/oink/benchmarks/README.md
index c99966e1..9685000b 100644
--- a/oink/benchmarks/README.md
+++ b/oink/benchmarks/README.md
@@ -153,19 +153,57 @@ python benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dty
### Fused Add + RMSNorm (vLLM-style, in-place)
This is a good roofline case study kernel (heavy read/write traffic, very little
-extra math):
+extra math). Oink exposes an **in-place** fused op that updates `x` and
+`residual`. Quack's fused kernel writes separate `out` and `residual_out`
+buffers, so the default benchmark baseline (`--quack-baseline kernel_inplace`)
+times Quack plus the copies needed to match Oink's in-place semantics. Use
+`--quack-baseline kernel` to time only the Quack kernel with preallocated
+outputs.
```bash
-CUDA_VISIBLE_DEVICES=0 python benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py --dtype bf16 --M 65536 --N 4096 \
- --json /tmp/fused_add_rmsnorm_sm100_bf16.json
+# DeepSeek-V3 hidden-size sweep
+PYTHONNOUSERSITE=1 CUTE_DSL_ARCH=sm_103a \
+ python benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py \
+ --dtype bf16 --dsv3 --iters 80 --warmup-ms 15 \
+ --quack-baseline kernel_inplace \
+ --json /tmp/oink_sm103_fused_add_rmsnorm_dsv3_bf16.json
+
+# DeepSeek-V4-Flash hidden-size sweep (N=7168)
+PYTHONNOUSERSITE=1 CUTE_DSL_ARCH=sm_103a \
+ python benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py \
+ --dtype bf16 --dsv4 --iters 80 --warmup-ms 15 \
+ --quack-baseline kernel_inplace \
+ --json /tmp/oink_sm103_fused_add_rmsnorm_dsv4_bf16.json
```
-Note on the Quack baseline: Oink exposes an **in-place** fused op (updates `x`
-and `residual`). Quack’s fused kernel produces `out` and `residual_out`
-out-of-place, so by default the benchmark times `quack::_rmsnorm_fwd` **plus**
-two explicit copies (`x.copy_(out)`, `residual.copy_(residual_out)`) to match the
-in-place semantics. Use `--quack-baseline kernel` to time only the Quack fused
-kernel with preallocated outputs.
+Current GB300 / SM103 BF16 results from correctness-gated runs:
+
+| suite | rows | speedup vs Quack (min / geomean / max) |
+|---|---:|---:|
+| DSv3 fused-add RMSNorm | 9 | 2.022x / 2.045x / 2.089x |
+| DSv4 fused-add RMSNorm | 3 | 2.030x / 2.192x / 2.521x |
+
+DSv3 per-shape results:
+
+| M | N | Oink ms | Quack ms | speedup | Oink TB/s |
+|---:|---:|---:|---:|---:|---:|
+| 4096 | 6144 | 0.0360 | 0.0727 | 2.022x | 5.598 |
+| 4096 | 7168 | 0.0396 | 0.0828 | 2.089x | 5.926 |
+| 4096 | 8192 | 0.0479 | 0.0993 | 2.076x | 5.610 |
+| 16384 | 6144 | 0.1206 | 0.2463 | 2.043x | 6.678 |
+| 16384 | 7168 | 0.1393 | 0.2830 | 2.031x | 6.742 |
+| 16384 | 8192 | 0.1574 | 0.3212 | 2.040x | 6.821 |
+| 65536 | 6144 | 0.4575 | 0.9285 | 2.030x | 7.041 |
+| 65536 | 7168 | 0.5329 | 1.0785 | 2.024x | 7.052 |
+| 65536 | 8192 | 0.6077 | 1.2466 | 2.052x | 7.068 |
+
+DSv4 per-shape results:
+
+| M | N | Oink ms | Quack ms | speedup | Oink TB/s |
+|---:|---:|---:|---:|---:|---:|
+| 4096 | 7168 | 0.0415 | 0.1047 | 2.521x | 5.655 |
+| 16384 | 7168 | 0.1388 | 0.2855 | 2.057x | 6.769 |
+| 65536 | 7168 | 0.5314 | 1.0785 | 2.030x | 7.072 |
### RMSNorm backward
From 8fd781906bd5e498e69474d2a89df4edbb8f217d Mon Sep 17 00:00:00 2001
From: Laura Wang <3700467+Laurawly@users.noreply.github.com>
Date: Wed, 29 Apr 2026 14:56:11 -0700
Subject: [PATCH 3/4] Fix Oink ruff import-order checks
---
.../kernelagent_oink/blackwell/_rmsnorm_simple_weightonly.py | 2 ++
oink/src/kernelagent_oink/blackwell/cross_entropy.py | 2 +-
oink/src/kernelagent_oink/blackwell/layernorm.py | 1 +
oink/src/kernelagent_oink/blackwell/softmax.py | 1 +
4 files changed, 5 insertions(+), 1 deletion(-)
diff --git a/oink/src/kernelagent_oink/blackwell/_rmsnorm_simple_weightonly.py b/oink/src/kernelagent_oink/blackwell/_rmsnorm_simple_weightonly.py
index d9d10734..6c7d9065 100644
--- a/oink/src/kernelagent_oink/blackwell/_rmsnorm_simple_weightonly.py
+++ b/oink/src/kernelagent_oink/blackwell/_rmsnorm_simple_weightonly.py
@@ -6,6 +6,8 @@
bf16 output. The shape table below is deliberately measured and narrow; shapes
not listed fall back to the generic RMSNorm pointer path.
"""
+# ruff: noqa: E402 # CuTeDSL cache setup must run before importing cutlass.
+
from dataclasses import dataclass
diff --git a/oink/src/kernelagent_oink/blackwell/cross_entropy.py b/oink/src/kernelagent_oink/blackwell/cross_entropy.py
index 06f72084..5bb15685 100644
--- a/oink/src/kernelagent_oink/blackwell/cross_entropy.py
+++ b/oink/src/kernelagent_oink/blackwell/cross_entropy.py
@@ -1,3 +1,4 @@
+# ruff: noqa: E402 # CuTeDSL cache setup must run before importing cutlass.
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -40,7 +41,6 @@
from __future__ import annotations
import math
-import os
from typing import Literal, Optional, Type
import torch
diff --git a/oink/src/kernelagent_oink/blackwell/layernorm.py b/oink/src/kernelagent_oink/blackwell/layernorm.py
index 4e5190b7..6b4b9c72 100644
--- a/oink/src/kernelagent_oink/blackwell/layernorm.py
+++ b/oink/src/kernelagent_oink/blackwell/layernorm.py
@@ -1,3 +1,4 @@
+# ruff: noqa: E402 # CuTeDSL cache setup must run before importing cutlass.
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/oink/src/kernelagent_oink/blackwell/softmax.py b/oink/src/kernelagent_oink/blackwell/softmax.py
index 3ee93c56..d364c5cf 100644
--- a/oink/src/kernelagent_oink/blackwell/softmax.py
+++ b/oink/src/kernelagent_oink/blackwell/softmax.py
@@ -1,3 +1,4 @@
+# ruff: noqa: E402 # CuTeDSL cache setup must run before importing cutlass.
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
From 956884ad07cd2fa14a636817e597b2acad23396e Mon Sep 17 00:00:00 2001
From: Laura Wang <3700467+Laurawly@users.noreply.github.com>
Date: Wed, 29 Apr 2026 15:02:54 -0700
Subject: [PATCH 4/4] Apply ruff formatting to Oink files
---
oink/benchmarks/readme/run_sm100_suite.py | 510 +++++++++---------
.../blackwell/_rmsnorm_simple_weightonly.py | 2 -
2 files changed, 256 insertions(+), 256 deletions(-)
diff --git a/oink/benchmarks/readme/run_sm100_suite.py b/oink/benchmarks/readme/run_sm100_suite.py
index 71920e38..4ef1bf46 100644
--- a/oink/benchmarks/readme/run_sm100_suite.py
+++ b/oink/benchmarks/readme/run_sm100_suite.py
@@ -196,260 +196,262 @@ def script(name: str) -> str:
]
)
- runs.extend([
- (
- "rmsnorm_fwd_quack_suite_wfp32",
- [
- py,
- script("benchmark_rmsnorm_sm100.py"),
- *common,
- "--weight-dtype",
- "fp32",
- "--quack-suite",
- "--iters",
- "200",
- "--warmup-ms",
- "25",
- "--json",
- os.path.join(out_dir, "rmsnorm_fwd_quack_suite_wfp32.json"),
- ],
- ),
- (
- "rmsnorm_fwd_dsv3_wfp32",
- [
- py,
- script("benchmark_rmsnorm_sm100.py"),
- *common,
- "--weight-dtype",
- "fp32",
- "--dsv3",
- "--iters",
- "200",
- "--warmup-ms",
- "25",
- "--json",
- os.path.join(out_dir, "rmsnorm_fwd_dsv3_wfp32.json"),
- ],
- ),
- (
- "rmsnorm_bwd_quack_suite_wfp32",
- [
- py,
- script("benchmark_rmsnorm_bwd_sm100.py"),
- *common,
- "--weight-dtype",
- "fp32",
- "--quack-suite",
- "--iters",
- "100",
- "--warmup-ms",
- "25",
- "--json",
- os.path.join(out_dir, "rmsnorm_bwd_quack_suite_wfp32.json"),
- ],
- ),
- (
- "rmsnorm_bwd_dsv3_wfp32",
- [
- py,
- script("benchmark_rmsnorm_bwd_sm100.py"),
- *common,
- "--weight-dtype",
- "fp32",
- "--dsv3",
- "--iters",
- "100",
- "--warmup-ms",
- "25",
- "--json",
- os.path.join(out_dir, "rmsnorm_bwd_dsv3_wfp32.json"),
- ],
- ),
- # vLLM inference-style RMSNorm (weight dtype == activation dtype).
- (
- "rmsnorm_fwd_quack_suite_wsame",
- [
- py,
- script("benchmark_rmsnorm_sm100.py"),
- *common,
- "--weight-dtype",
- "same",
- "--quack-suite",
- "--iters",
- "200",
- "--warmup-ms",
- "25",
- "--json",
- os.path.join(out_dir, "rmsnorm_fwd_quack_suite_wsame.json"),
- ],
- ),
- (
- "rmsnorm_fwd_dsv3_wsame",
- [
- py,
- script("benchmark_rmsnorm_sm100.py"),
- *common,
- "--weight-dtype",
- "same",
- "--dsv3",
- "--iters",
- "200",
- "--warmup-ms",
- "25",
- "--json",
- os.path.join(out_dir, "rmsnorm_fwd_dsv3_wsame.json"),
- ],
- ),
- (
- "rmsnorm_bwd_quack_suite_wsame",
- [
- py,
- script("benchmark_rmsnorm_bwd_sm100.py"),
- *common,
- "--weight-dtype",
- "same",
- "--quack-suite",
- "--iters",
- "100",
- "--warmup-ms",
- "25",
- "--json",
- os.path.join(out_dir, "rmsnorm_bwd_quack_suite_wsame.json"),
- ],
- ),
- (
- "rmsnorm_bwd_dsv3_wsame",
- [
- py,
- script("benchmark_rmsnorm_bwd_sm100.py"),
- *common,
- "--weight-dtype",
- "same",
- "--dsv3",
- "--iters",
- "100",
- "--warmup-ms",
- "25",
- "--json",
- os.path.join(out_dir, "rmsnorm_bwd_dsv3_wsame.json"),
- ],
- ),
- (
- "fused_add_rmsnorm_dsv3",
- [
- py,
- script("benchmark_fused_add_rmsnorm_sm100.py"),
- *common,
- "--dsv3",
- "--quack-baseline",
- "kernel_inplace",
- "--iters",
- "200",
- "--warmup-ms",
- "25",
- "--json",
- os.path.join(out_dir, "fused_add_rmsnorm_dsv3.json"),
- ],
- ),
- (
- "softmax_fwd_bwd_quack_suite",
- [
- py,
- script("benchmark_softmax_sm100.py"),
- *common,
- "--mode",
- "fwd_bwd",
- "--quack-suite",
- "--iters",
- "50",
- "--warmup-ms",
- "25",
- "--json",
- os.path.join(out_dir, "softmax_fwd_bwd_quack_suite.json"),
- ],
- ),
- (
- "softmax_fwd_bwd_dsv3",
- [
- py,
- script("benchmark_softmax_sm100.py"),
- *common,
- "--mode",
- "fwd_bwd",
- "--dsv3",
- "--iters",
- "50",
- "--warmup-ms",
- "25",
- "--json",
- os.path.join(out_dir, "softmax_fwd_bwd_dsv3.json"),
- ],
- ),
- (
- "cross_entropy_fwd_bwd_quack_suite",
- [
- py,
- script("benchmark_cross_entropy_sm100.py"),
- *common,
- "--mode",
- "fwd_bwd",
- "--quack-suite",
- "--iters",
- "50",
- "--warmup-ms",
- "25",
- "--json",
- os.path.join(out_dir, "cross_entropy_fwd_bwd_quack_suite.json"),
- ],
- ),
- (
- "cross_entropy_fwd_bwd_dsv3",
- [
- py,
- script("benchmark_cross_entropy_sm100.py"),
- *common,
- "--mode",
- "fwd_bwd",
- "--dsv3",
- "--iters",
- "50",
- "--warmup-ms",
- "25",
- "--json",
- os.path.join(out_dir, "cross_entropy_fwd_bwd_dsv3.json"),
- ],
- ),
- (
- "layernorm_fwd_quack_suite",
- [
- py,
- script("benchmark_layernorm_sm100.py"),
- *common,
- "--quack-suite",
- "--iters",
- "200",
- "--warmup-ms",
- "25",
- "--json",
- os.path.join(out_dir, "layernorm_fwd_quack_suite.json"),
- ],
- ),
- (
- "layernorm_fwd_dsv3",
- [
- py,
- script("benchmark_layernorm_sm100.py"),
- *common,
- "--dsv3",
- "--iters",
- "200",
- "--warmup-ms",
- "25",
- "--json",
- os.path.join(out_dir, "layernorm_fwd_dsv3.json"),
- ],
- ),
- ])
+ runs.extend(
+ [
+ (
+ "rmsnorm_fwd_quack_suite_wfp32",
+ [
+ py,
+ script("benchmark_rmsnorm_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "fp32",
+ "--quack-suite",
+ "--iters",
+ "200",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_fwd_quack_suite_wfp32.json"),
+ ],
+ ),
+ (
+ "rmsnorm_fwd_dsv3_wfp32",
+ [
+ py,
+ script("benchmark_rmsnorm_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "fp32",
+ "--dsv3",
+ "--iters",
+ "200",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_fwd_dsv3_wfp32.json"),
+ ],
+ ),
+ (
+ "rmsnorm_bwd_quack_suite_wfp32",
+ [
+ py,
+ script("benchmark_rmsnorm_bwd_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "fp32",
+ "--quack-suite",
+ "--iters",
+ "100",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_bwd_quack_suite_wfp32.json"),
+ ],
+ ),
+ (
+ "rmsnorm_bwd_dsv3_wfp32",
+ [
+ py,
+ script("benchmark_rmsnorm_bwd_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "fp32",
+ "--dsv3",
+ "--iters",
+ "100",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_bwd_dsv3_wfp32.json"),
+ ],
+ ),
+ # vLLM inference-style RMSNorm (weight dtype == activation dtype).
+ (
+ "rmsnorm_fwd_quack_suite_wsame",
+ [
+ py,
+ script("benchmark_rmsnorm_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "same",
+ "--quack-suite",
+ "--iters",
+ "200",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_fwd_quack_suite_wsame.json"),
+ ],
+ ),
+ (
+ "rmsnorm_fwd_dsv3_wsame",
+ [
+ py,
+ script("benchmark_rmsnorm_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "same",
+ "--dsv3",
+ "--iters",
+ "200",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_fwd_dsv3_wsame.json"),
+ ],
+ ),
+ (
+ "rmsnorm_bwd_quack_suite_wsame",
+ [
+ py,
+ script("benchmark_rmsnorm_bwd_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "same",
+ "--quack-suite",
+ "--iters",
+ "100",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_bwd_quack_suite_wsame.json"),
+ ],
+ ),
+ (
+ "rmsnorm_bwd_dsv3_wsame",
+ [
+ py,
+ script("benchmark_rmsnorm_bwd_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "same",
+ "--dsv3",
+ "--iters",
+ "100",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_bwd_dsv3_wsame.json"),
+ ],
+ ),
+ (
+ "fused_add_rmsnorm_dsv3",
+ [
+ py,
+ script("benchmark_fused_add_rmsnorm_sm100.py"),
+ *common,
+ "--dsv3",
+ "--quack-baseline",
+ "kernel_inplace",
+ "--iters",
+ "200",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "fused_add_rmsnorm_dsv3.json"),
+ ],
+ ),
+ (
+ "softmax_fwd_bwd_quack_suite",
+ [
+ py,
+ script("benchmark_softmax_sm100.py"),
+ *common,
+ "--mode",
+ "fwd_bwd",
+ "--quack-suite",
+ "--iters",
+ "50",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "softmax_fwd_bwd_quack_suite.json"),
+ ],
+ ),
+ (
+ "softmax_fwd_bwd_dsv3",
+ [
+ py,
+ script("benchmark_softmax_sm100.py"),
+ *common,
+ "--mode",
+ "fwd_bwd",
+ "--dsv3",
+ "--iters",
+ "50",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "softmax_fwd_bwd_dsv3.json"),
+ ],
+ ),
+ (
+ "cross_entropy_fwd_bwd_quack_suite",
+ [
+ py,
+ script("benchmark_cross_entropy_sm100.py"),
+ *common,
+ "--mode",
+ "fwd_bwd",
+ "--quack-suite",
+ "--iters",
+ "50",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "cross_entropy_fwd_bwd_quack_suite.json"),
+ ],
+ ),
+ (
+ "cross_entropy_fwd_bwd_dsv3",
+ [
+ py,
+ script("benchmark_cross_entropy_sm100.py"),
+ *common,
+ "--mode",
+ "fwd_bwd",
+ "--dsv3",
+ "--iters",
+ "50",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "cross_entropy_fwd_bwd_dsv3.json"),
+ ],
+ ),
+ (
+ "layernorm_fwd_quack_suite",
+ [
+ py,
+ script("benchmark_layernorm_sm100.py"),
+ *common,
+ "--quack-suite",
+ "--iters",
+ "200",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "layernorm_fwd_quack_suite.json"),
+ ],
+ ),
+ (
+ "layernorm_fwd_dsv3",
+ [
+ py,
+ script("benchmark_layernorm_sm100.py"),
+ *common,
+ "--dsv3",
+ "--iters",
+ "200",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "layernorm_fwd_dsv3.json"),
+ ],
+ ),
+ ]
+ )
print(f"Writing results to: {out_dir}", flush=True)
for name, cmd in runs:
diff --git a/oink/src/kernelagent_oink/blackwell/_rmsnorm_simple_weightonly.py b/oink/src/kernelagent_oink/blackwell/_rmsnorm_simple_weightonly.py
index 6c7d9065..63e50832 100644
--- a/oink/src/kernelagent_oink/blackwell/_rmsnorm_simple_weightonly.py
+++ b/oink/src/kernelagent_oink/blackwell/_rmsnorm_simple_weightonly.py
@@ -8,7 +8,6 @@
"""
# ruff: noqa: E402 # CuTeDSL cache setup must run before importing cutlass.
-
from dataclasses import dataclass
import cuda.bindings.driver as cuda
@@ -235,7 +234,6 @@ def kernel(
cute.copy(copy_atom_store, tXrO, tXgO)
-
def _get_simple_weightonly_config(
x: Tensor,
weight: Tensor,