# KernelAgent-Oink

KernelAgent-Oink is a lightweight **CuTeDSL (CUTLASS DSL) kernel package** for
NVIDIA Blackwell **SM10x** GPUs. It can be used standalone or loaded as a
**vLLM general plugin**.

Current `torch.library.custom_op` entrypoints, registered under the `oink::` namespace:

- `torch.ops.oink.rmsnorm(x, weight, eps) -> Tensor`
- `torch.ops.oink.fused_add_rms_norm(x, residual, weight, eps) -> None` (in-place)
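
For intuition, here is a minimal pure-PyTorch sketch of the semantics these ops implement. This assumes the usual vLLM conventions (the in-place op first accumulates the residual, then normalizes, mutating both inputs); it is a reference sketch, not the kernel code:

```python
import torch

def rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # y = x / sqrt(mean(x^2) + eps) * weight, accumulated in fp32 for stability.
    xf = x.float()
    y = xf * torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (y * weight.float()).to(x.dtype)

def fused_add_rms_norm_ref(x, residual, weight, eps) -> None:
    # In place: residual <- x + residual, then x <- rmsnorm(residual).
    residual.add_(x)
    x.copy_(rmsnorm_ref(residual, weight, eps))
```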

The repo also contains benchmark-facing Blackwell kernels for LayerNorm, Softmax (fwd+bwd), and CrossEntropy (fwd+bwd).

## Requirements

- Blackwell GPU for optimized CuTeDSL paths; other GPUs use correctness-first
PyTorch fallbacks.
- `nvidia-cutlass-dsl>=4.4.2`
- `cuda-python`
- `torch` from the surrounding environment / vLLM

Recommended env vars:

```bash
export PYTORCH_ALLOC_CONF=expandable_segments:True
export CUTE_DSL_ARCH=sm_103a # GB300 / SM103
# export CUTE_DSL_ARCH=sm_100a # GB200/B200 / SM100
```
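
If you prefer to pick the arch programmatically, a small helper along these lines works (not part of the package; the capability mapping simply mirrors the two comments above):

```python
import os
import torch

# Map the detected compute capability to a CUTE_DSL_ARCH value:
# (10, 0) is GB200/B200-class SM100, (10, 3) is GB300-class SM103.
cap = torch.cuda.get_device_capability()
arch = {(10, 0): "sm_100a", (10, 3): "sm_103a"}.get(cap)
if arch is not None:
    os.environ.setdefault("CUTE_DSL_ARCH", arch)
```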

## Install

From the `KernelAgent` repo root:

```bash
pip install -e ./oink
pip install -e "./oink[bench]" # optional benchmark/plot deps
```

A reproducible GB300 benchmark environment used for the results below:

```bash
pip install -e "./oink[bench]"
conda create -y -n cute python=3.12
conda run -n cute python -m pip install --upgrade pip setuptools wheel packaging ninja
conda run -n cute python -m pip install --upgrade --index-url https://download.pytorch.org/whl/cu130 torch
conda run -n cute python -m pip install 'nvidia-cutlass-dsl==4.4.2' cuda-python triton matplotlib
conda run -n cute python -m pip install -e './oink[bench]'
```
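
A quick sanity check that your environment matches the setup reported below (the version strings are from the GB300 run; yours may differ):

```python
import torch

# Expect capability (10, 3) and a cu130 build to reproduce the GB300 numbers.
print(torch.__version__, torch.version.cuda)
print(torch.cuda.get_device_name(0), torch.cuda.get_device_capability(0))
```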

## Usage

### vLLM plugin

```bash
export VLLM_USE_OINK_RMSNORM=1
```

When using `torch.compile` / CUDA graphs, keep vLLM RMSNorm as a custom op:

```python
from vllm import LLM

llm = LLM(
    # The full argument list is collapsed in the diff view; the model name
    # below is a placeholder. compilation_config keeps rms_norm un-fused so
    # the Oink op can intercept it.
    model="<your-model>",
    compilation_config={"custom_ops": ["+rms_norm"]},
)
```

Without `+rms_norm`, Inductor may fuse RMSNorm into larger kernels and neither
vLLM’s CUDA RMSNorm nor Oink will run.
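
Once vLLM has loaded the plugin, you can probe for the op directly (a simple check, not an official API):

```python
import torch

# True only after the plugin (or a manual import) has registered oink::rmsnorm.
print(hasattr(torch.ops.oink, "rmsnorm"))
```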

### Direct PyTorch

For standalone use (outside vLLM), register the custom ops once:

```python
import torch

import kernelagent_oink

# The one-time registration call is collapsed in the diff view;
# `register_ops()` is a hypothetical stand-in for the package's entrypoint.
kernelagent_oink.register_ops()

x = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.ones(4096, device="cuda", dtype=torch.bfloat16)
y = torch.ops.oink.rmsnorm(x, w, 1e-6)
```
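
The in-place op follows the same pattern; a sketch assuming vLLM-style in-place semantics (both `x` and `residual` are overwritten):

```python
x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
residual = torch.randn_like(x)
w = torch.ones(4096, device="cuda", dtype=torch.bfloat16)

# After the call: residual holds x + residual, x holds rmsnorm(residual) * w.
torch.ops.oink.fused_add_rms_norm(x, residual, w, 1e-6)
```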

## Benchmarks

Benchmark details and commands are in [`benchmarks/README.md`](benchmarks/README.md).
Reported numbers are correctness-gated against PyTorch references before timing.

Current GB300 / SM103 setup:

- NVIDIA GB300, capability `(10, 3)`, `CUTE_DSL_ARCH=sm_103a`
- `torch==2.11.0+cu130`, CUDA `13.0`
- `nvidia-cutlass-dsl==4.4.2`, `cuda-python==13.2.0`
- measured BF16 STREAM-like roof: **7.140 TB/s**

<div align="center">
<img src="benchmarks/media/gb300_bf16_qk_norm_oink_vs_cutedsl_roofline.svg" alt="GB300 BF16: Q/K-norm roofline (Oink vs CuTeDSL)">
<img src="benchmarks/media/sm103_bf16_oink_vs_quack_with_layernorm.svg" alt="SM103 / GB300 BF16 benchmark summary">
</div>

Quack-suite BF16 summary (`N=4096`):

| op | row shapes | geomean speedup vs Quack | large-row roofline note |
|---|---:|---:|---|
| RMSNorm fwd, weight=same | 19 | 1.019x | near measured roof on large rows |
| RMSNorm fwd, weight=fp32 | 19 | 1.100x | near measured roof on large rows |
| LayerNorm fwd | 19 | 1.241x | near measured roof on large rows |
| Softmax fwd+bwd | 19 | 1.673x | near measured roof on large rows |
| CrossEntropy fwd+bwd | 19 | 1.635x | mixed memory/SFU behavior |

Historical plots remain under `benchmarks/media/`:

- `sm100_*`: historical SM100 / B200 runs.
- `gb300_bf16_qk_norm_oink_vs_cutedsl_roofline.svg`: historical GB300 Q/K-norm
harness, separate from the Quack-suite table above.

## Links

| What | Link |
|---|---|
| Quack baseline | https://github.com/Dao-AILab/quack |
| KernelAgent | https://github.com/meta-pytorch/KernelAgent |
| vLLM Oink RMSNorm PR | https://github.com/vllm-project/vllm/pull/31828 |