diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_all.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_all.py new file mode 100644 index 0000000..6d6eb3d --- /dev/null +++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_all.py @@ -0,0 +1,325 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Benchmark aten vs quack vs oink RMSNorm: normal dispatch + CUDA graph. + +All calls go through ``torch.ops.aten._fused_rms_norm``. +Quack is registered via ``torch._native`` (quack PR pattern). +Oink is registered via ``kernelagent_oink.register_all_kernels()``. + +Produces four tables: + - Forward (normal dispatch) + - Forward + Backward (normal dispatch) + - Forward (CUDA graph) + - Forward + Backward (CUDA graph) + +Usage:: + + python oink/benchmarks/benchmark/benchmark_rmsnorm_all.py +""" + +from __future__ import annotations + +import json +import os +import subprocess +import sys +import tempfile + +os.environ.setdefault("TORCH_NATIVE_SKIP_VERSION_CHECK", "1") + + +# --------------------------------------------------------------------------- +# Worker code: runs in a subprocess per mode to avoid cross-contamination. 
+# --------------------------------------------------------------------------- + +WORKER_CODE = r""" +import json, os, sys +os.environ.setdefault("TORCH_NATIVE_SKIP_VERSION_CHECK", "1") + +import torch +from triton.testing import do_bench + +DTYPE = torch.bfloat16 + +def bench_normal(fn, warmup=50, rep=200): + return do_bench(fn, warmup=warmup, rep=rep, return_mode="median") + +def bench_cudagraph(fn, warmup=50, rep=200): + for _ in range(warmup): + fn() + torch.cuda.synchronize() + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + fn() + torch.cuda.synchronize() + return do_bench(lambda: g.replay(), warmup=10, rep=rep, return_mode="median") + +mode = sys.argv[1] +shapes_json = sys.argv[2] +SHAPES = json.loads(shapes_json) + +if mode == "oink": + import kernelagent_oink + kernelagent_oink.register_all_kernels(force=True) + +# Warm up +for M, N in SHAPES: + x = torch.randn(M, N, dtype=DTYPE, device="cuda") + w = torch.randn(N, dtype=DTYPE, device="cuda") + torch.ops.aten._fused_rms_norm(x, [N], w, 1e-5) +torch.cuda.synchronize() + +results = {} +for M, N in SHAPES: + x = torch.randn(M, N, dtype=DTYPE, device="cuda", requires_grad=True) + w = torch.randn(N, dtype=DTYPE, device="cuda", requires_grad=True) + grad = torch.randn(M, N, dtype=DTYPE, device="cuda") + + # Forward (normal) + def fn_fwd(x=x, w=w, N=N): + return torch.ops.aten._fused_rms_norm(x, [N], w, 1e-5) + fwd_ms = bench_normal(fn_fwd) + + # Forward + Backward (normal) + x_ = x.detach().requires_grad_(True) + w_ = w.detach().requires_grad_(True) + def fn_fwdbwd(x_=x_, w_=w_, N=N, grad=grad): + y, _ = torch.ops.aten._fused_rms_norm(x_, [N], w_, 1e-5) + y.backward(grad) + fwdbwd_ms = bench_normal(fn_fwdbwd) + + # Forward (CUDA graph) + x_g = torch.randn(M, N, dtype=DTYPE, device="cuda") + w_g = torch.randn(N, dtype=DTYPE, device="cuda") + def fn_fwd_g(x=x_g, w=w_g, N=N): + return torch.ops.aten._fused_rms_norm(x, [N], w, 1e-5) + try: + fwd_graph_ms = bench_cudagraph(fn_fwd_g) + except Exception: + 
fwd_graph_ms = -1.0 + + # Forward + Backward (CUDA graph) + x_gb = torch.randn(M, N, dtype=DTYPE, device="cuda", requires_grad=True) + w_gb = torch.randn(N, dtype=DTYPE, device="cuda", requires_grad=True) + grad_gb = torch.randn(M, N, dtype=DTYPE, device="cuda") + def fn_fwdbwd_g(x=x_gb, w=w_gb, N=N, grad=grad_gb): + y, _ = torch.ops.aten._fused_rms_norm(x, [N], w, 1e-5) + y.backward(grad) + try: + fwdbwd_graph_ms = bench_cudagraph(fn_fwdbwd_g) + except Exception: + fwdbwd_graph_ms = -1.0 + + results[f"{M}x{N}"] = { + "fwd": fwd_ms, + "fwdbwd": fwdbwd_ms, + "fwd_graph": fwd_graph_ms, + "fwdbwd_graph": fwdbwd_graph_ms, + } + +print(json.dumps({"mode": mode, "results": results})) +""" + + +# --------------------------------------------------------------------------- +# Main: orchestrates subprocesses and prints tables. +# --------------------------------------------------------------------------- + +SHAPES = [ + [1, 4096], + [1, 8192], + [32, 4096], + [32, 8192], + [256, 4096], + [256, 8192], + [1024, 4096], + [1024, 8192], + [4096, 4096], + [4096, 8192], + [16384, 4096], + [16384, 8192], + [65536, 4096], + [65536, 8192], +] + +COL_W = { # column widths + "shape": 14, + "ms": 10, + "ratio": 8, +} + + +def find_norm_dir(): + import torch + from pathlib import Path + + d = Path(torch.__file__).parent / "_native" / "ops" / "norm" + return str(d) if d.is_dir() else None + + +def run_mode(mode, norm_dir, shapes): + init_file = os.path.join(norm_dir, "__init__.py") + + if mode in ("aten", "oink"): + with open(init_file, "w") as f: + f.write("") + elif mode == "quack": + with open(init_file, "w") as f: + f.write("from . 
import rmsnorm_impl # noqa: F401\n") + + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as tmp: + tmp.write(WORKER_CODE) + tmp_path = tmp.name + + try: + result = subprocess.run( + [sys.executable, tmp_path, mode, json.dumps(shapes)], + capture_output=True, + text=True, + timeout=600, + ) + if result.returncode != 0: + print(f" [{mode}] FAILED: {result.stderr[-300:]}", file=sys.stderr) + return None + return json.loads(result.stdout.strip())["results"] + finally: + os.unlink(tmp_path) + + +def _fmt_ms(v): + return f"{v:>{COL_W['ms']}.4f}" if v > 0 else "FAIL".rjust(COL_W["ms"]) + + +def _fmt_ratio(n, d): + if d <= 0 or n <= 0: + return "N/A".rjust(COL_W["ratio"]) + return f"{f'{n / d:.2f}x':>{COL_W['ratio']}}" + + +def print_table(title, subtitle, aten, quack, oink, key): + sw, mw, rw = COL_W["shape"], COL_W["ms"], COL_W["ratio"] + w = [sw, mw, mw, mw, rw, rw, rw] + + def hr(left, mid, right): + return left + mid.join("─" * (c + 2) for c in w) + right + + hdr = ( + f"│ {'Shape (M,N)':^{sw}} " + f"│ {'Aten (ms)':^{mw}} " + f"│ {'Quack (ms)':^{mw}} " + f"│ {'Oink (ms)':^{mw}} " + f"│ {'Q/A':^{rw}} " + f"│ {'O/A':^{rw}} " + f"│ {'O/Q':^{rw}} │" + ) + + print() + print(f" {title}") + print(f" {subtitle}") + print(hr("┌", "┬", "┐")) + print(hdr) + print(hr("├", "┼", "┤")) + + for shape in aten: + M, N = shape.split("x") + a, q, o = aten[shape][key], quack[shape][key], oink[shape][key] + row = ( + f"│ {f'({M},{N})':>{sw}} " + f"│ {_fmt_ms(a)} " + f"│ {_fmt_ms(q)} " + f"│ {_fmt_ms(o)} " + f"│ {_fmt_ratio(a, q)} " + f"│ {_fmt_ratio(a, o)} " + f"│ {_fmt_ratio(q, o)} │" + ) + print(row) + + print(hr("└", "┴", "┘")) + + +def main(): + import torch + + print("=" * 72) + print(" RMSNorm Kernel Benchmark: Aten vs Quack vs Oink") + print("=" * 72) + print(f" Device : {torch.cuda.get_device_name(0)}") + print(f" Torch : {torch.__version__}") + print(" Dtype : bfloat16") + print(" Quack : registered via torch._native (quack PR)") + print(" Oink : registered 
via kernelagent_oink.register_all_kernels()") + print(" Bench : triton.testing.do_bench (median, 200 reps)") + + norm_dir = find_norm_dir() + if norm_dir is None: + print("ERROR: torch._native/ops/norm/ not found.", file=sys.stderr) + sys.exit(1) + + print() + print("Running aten...") + aten = run_mode("aten", norm_dir, SHAPES) + print("Running quack...") + quack = run_mode("quack", norm_dir, SHAPES) + print("Running oink...") + oink = run_mode("oink", norm_dir, SHAPES) + + # Restore + with open(os.path.join(norm_dir, "__init__.py"), "w") as f: + f.write("from . import rmsnorm_impl # noqa: F401\n") + + if not all([aten, quack, oink]): + print("ERROR: one or more modes failed.", file=sys.stderr) + sys.exit(1) + + print_table( + "Forward — Normal Dispatch", + "Standard Python dispatch through torch.ops.aten._fused_rms_norm.", + aten, + quack, + oink, + "fwd", + ) + print_table( + "Forward + Backward — Normal Dispatch", + "Fwd + autograd backward, standard Python dispatch.", + aten, + quack, + oink, + "fwdbwd", + ) + print_table( + "Forward — CUDA Graph (zero Python overhead)", + "Kernel captured once, replayed without re-entering Python.", + aten, + quack, + oink, + "fwd_graph", + ) + print_table( + "Forward + Backward — CUDA Graph (zero Python overhead)", + "Fwd + bwd captured once, replayed without re-entering Python.", + aten, + quack, + oink, + "fwdbwd_graph", + ) + + print() + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/oink/benchmarks/readme/run_sm100_suite.py b/oink/benchmarks/readme/run_sm100_suite.py index e45a837..05bd211 100644 --- a/oink/benchmarks/readme/run_sm100_suite.py +++ b/oink/benchmarks/readme/run_sm100_suite.py @@ -32,7 +32,12 @@ def _run(cmd: List[str], *, dry_run: bool) -> None: print("+", " ".join(cmd), flush=True) if dry_run: return - subprocess.run(cmd, check=True) + result = subprocess.run(cmd) + if result.returncode != 0: + print( + f"WARNING: command exited with code {result.returncode}, continuing...", + 
# --- oink/src/kernelagent_oink/__init__.py (updated section) ---

def _check_and_setup() -> bool:
    """Probe the runtime environment for Oink's kernel requirements.

    Verifies, in order: torch imports, a CUDA device is visible, the device
    is SM 100 or newer, and the CuTeDSL dependencies are importable; finally
    seeds ``CUTE_DSL_ARCH`` (respecting a pre-existing value).

    Returns True only when every check passes. Never raises.
    """
    try:
        import torch
    except Exception as e:
        logger.debug("Oink plugin: torch import failed: %s", e)
        return False

    try:
        if not torch.cuda.is_available():
            logger.debug("Oink plugin: CUDA not available; skipping")
            return False

        device_index = _infer_cuda_device_index()
        major, minor = torch.cuda.get_device_capability(device_index)
        if 10 * int(major) + int(minor) < 100:
            # Pre-Blackwell device: the kernels are not applicable.
            return False

        try:
            import cutlass  # noqa: F401
            import cuda.bindings.driver as _cuda  # noqa: F401
        except Exception as e:
            logger.warning(
                "Oink plugin: CuTeDSL deps missing; skipping. "
                "Install `nvidia-cutlass-dsl` + `cuda-python`. Error: %s",
                e,
            )
            return False

        # Point CuTeDSL at a target arch unless the user already chose one.
        os.environ.setdefault(
            "CUTE_DSL_ARCH", _compute_cutedsl_arch(int(major), int(minor))
        )
        return True
    except Exception as e:
        logger.exception("Oink plugin: setup failed: %s", e)
        return False


def register(*, force: bool = False) -> None:
    """Register Oink torch custom ops (``torch.ops.oink.*``).

    This registers ``torch.ops.oink.rmsnorm`` and
    ``torch.ops.oink.fused_add_rms_norm`` for use by vLLM's direct-call path.
    It does NOT override aten ops — use :func:`register_all_kernels` for that.

    - vLLM plugin mode (default): no-op unless ``VLLM_USE_OINK_RMSNORM`` is truthy.
    - Standalone mode: pass ``force=True`` to register explicitly.

    Safe to call repeatedly and from multiple processes; never raises.
    """
    global _OPS_REGISTERED

    if _OPS_REGISTERED:
        return

    # Env gate: installing the package must not change behavior by itself.
    if not force and not _env_truthy("VLLM_USE_OINK_RMSNORM"):
        return

    if not _check_and_setup():
        return

    try:
        # The import itself registers ops via torch.library.custom_op decorators.
        from .blackwell import oink_custom_ops  # noqa: F401
    except Exception as e:
        logger.exception("Oink plugin: failed to register custom ops: %s", e)
        return

    _OPS_REGISTERED = True


_ALL_KERNELS_REGISTERED = False


def register_all_kernels(*, force: bool = False) -> None:
    """Override aten ops with Oink's kernels.

    Checks CUDA/SM100/deps, sets up the CuTeDSL environment, then overrides
    ``aten::_fused_rms_norm`` and ``aten::_fused_rms_norm_backward`` on CUDA.

    Does NOT register ``torch.ops.oink.*`` custom ops — use :func:`register`
    separately if those are needed (e.g. for vLLM's direct-call path).

    Args:
        force: If *True*, bypass the ``VLLM_USE_OINK_RMSNORM`` env gate.
    """
    global _ALL_KERNELS_REGISTERED

    if _ALL_KERNELS_REGISTERED:
        return
    if not force and not _env_truthy("VLLM_USE_OINK_RMSNORM"):
        return
    if not _check_and_setup():
        return

    try:
        from .aten_override import override_all_kernels

        override_all_kernels()
    except Exception as e:
        logger.exception("Oink: failed to override aten ops: %s", e)
        return

    _ALL_KERNELS_REGISTERED = True


def unregister_all_kernels() -> None:
    """Remove the aten override. Can be followed by :func:`register_all_kernels`."""
    global _ALL_KERNELS_REGISTERED
    try:
        from .aten_override import restore_all_kernels

        restore_all_kernels()
    except Exception:
        # Best-effort teardown: never raise during unregistration.
        pass
    _ALL_KERNELS_REGISTERED = False


__all__ = ["register", "register_all_kernels", "unregister_all_kernels"]
+ +""" +Override Aten kernels with Oink's Blackwell CuTeDSL Kernels. + +Currently overrides: +- ``aten::_fused_rms_norm`` → ``rmsnorm_forward`` +- ``aten::_fused_rms_norm_backward`` → ``rmsnorm_backward`` + +Follows the quack PR pattern: ``with_keyset=True``, fallback via ``call_boxed``. +Calls ``rmsnorm_forward`` / ``rmsnorm_backward`` directly to get all kernel +optimizations (ptr fast-launch, atomic dW, _reduce_partial_sum_fp32). +""" + +from __future__ import annotations + +import importlib +import logging +import math +from functools import cache, partial +from typing import List, Optional, Tuple + +import torch + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Lazy imports (cached) +# --------------------------------------------------------------------------- + + +@cache +def _oink_rmsnorm(): + return importlib.import_module("kernelagent_oink.blackwell.rmsnorm") + + +# --------------------------------------------------------------------------- +# Device support (cached) +# --------------------------------------------------------------------------- + + +@cache +def _get_device_major(device: torch.device) -> int: + major, _ = torch.cuda.get_device_capability(device) + return major + + +def _is_supported(input: torch.Tensor) -> bool: + return ( + input.dtype in (torch.float16, torch.bfloat16, torch.float32) + and _get_device_major(input.device) >= 10 + and input.shape[-1] >= 128 # oink kernels require N >= 128 + ) + + +# --------------------------------------------------------------------------- +# Reshape helpers (match quack's norms.py) +# --------------------------------------------------------------------------- + + +def _reshape_2d(t: torch.Tensor, M: int, N: int) -> torch.Tensor: + if t.ndim == 2 and t.shape[0] == M and t.shape[1] == N and t.is_contiguous(): + return t + return t.reshape(M, N).contiguous() + + +def _flatten_rstd(t: torch.Tensor, M: int) -> torch.Tensor: + if t.ndim 
== 1 and t.shape[0] == M: + return t + if t.is_contiguous() and t.numel() == M: + return t.detach().view(M) + return t.reshape(M).contiguous() + + +# ========================================================================= +# RMSNorm forward +# ========================================================================= + + +def _fused_rms_norm_impl( + dispatch_keys: torch.DispatchKeySet, + input: torch.Tensor, + normalized_shape: List[int], + weight: Optional[torch.Tensor], + eps: Optional[float], + *, + fallback_kernel, +) -> Tuple[torch.Tensor, torch.Tensor]: + if not _is_supported(input): + return fallback_kernel.call_boxed( + dispatch_keys, input, normalized_shape, weight, eps + ) + if eps is None: + eps = 1e-6 + + input_shape = input.shape + N = math.prod(normalized_shape) + M = input.numel() // N + + x = input.reshape(M, N) + + if weight is not None and weight.ndim != 1: + weight = weight.view(N) + + y, rstd, _ = _oink_rmsnorm().rmsnorm_forward( + x, + weight=weight, + bias=None, + residual=None, + eps=eps, + store_rstd=True, + ) + + y = y.reshape(input_shape) + stat_shape = list(input_shape[: -len(normalized_shape)]) + [1] * len( + normalized_shape + ) + rstd = rstd.view(stat_shape) + return y, rstd + + +# ========================================================================= +# RMSNorm backward +# ========================================================================= + + +def _fused_rms_norm_backward_impl( + dispatch_keys: torch.DispatchKeySet, + grad_out: torch.Tensor, + input: torch.Tensor, + normalized_shape: List[int], + rstd: torch.Tensor, + weight: Optional[torch.Tensor], + output_mask: List[bool], + *, + fallback_kernel, +) -> Tuple[torch.Tensor, torch.Tensor]: + if not _is_supported(input): + return fallback_kernel.call_boxed( + dispatch_keys, + grad_out, + input, + normalized_shape, + rstd, + weight, + output_mask, + ) + + N = math.prod(normalized_shape) + M = input.numel() // N + + x = _reshape_2d(input, M, N) + dout = _reshape_2d(grad_out, 
M, N) + rstd_flat = _flatten_rstd(rstd, M) + + w = weight if output_mask[1] else None + dx, dw, _db, _dres = _oink_rmsnorm().rmsnorm_backward( + x, + w, + dout, + rstd_flat, + dresidual_out=None, + has_bias=False, + has_residual=False, + ) + + grad_input: torch.Tensor | None = dx.reshape(input.shape) + grad_weight: torch.Tensor | None = dw + + # Match native _fused_rms_norm_backward: return None for masked outputs. + if not output_mask[0]: + grad_input = None + if not output_mask[1]: + grad_weight = None + + return grad_input, grad_weight + + +# ========================================================================= +# Registration +# ========================================================================= + +_OVERRIDE_LIB: torch.library.Library | None = None + + +def override_all_kernels() -> None: + """Override Aten's kernels on CUDA with Oink's kernels.""" + global _OVERRIDE_LIB + if _OVERRIDE_LIB is not None: + return + + fwd_fallback = torch.library.get_kernel("aten::_fused_rms_norm", "CUDA") + bwd_fallback = torch.library.get_kernel("aten::_fused_rms_norm_backward", "CUDA") + + fwd_impl = partial(_fused_rms_norm_impl, fallback_kernel=fwd_fallback) + bwd_impl = partial(_fused_rms_norm_backward_impl, fallback_kernel=bwd_fallback) + + lib = torch.library.Library("aten", "IMPL") + lib.impl("_fused_rms_norm", fwd_impl, "CUDA", with_keyset=True, allow_override=True) + lib.impl( + "_fused_rms_norm_backward", + bwd_impl, + "CUDA", + with_keyset=True, + allow_override=True, + ) + _OVERRIDE_LIB = lib + logger.info("Oink: overrode aten::_fused_rms_norm on CUDA") + + +def restore_all_kernels() -> None: + """Remove the override and restore PyTorch's native CUDA kernels.""" + global _OVERRIDE_LIB + if _OVERRIDE_LIB is None: + return + _OVERRIDE_LIB = None + + +__all__ = [ + "override_all_kernels", + "restore_all_kernels", +] diff --git a/oink/tests/test_aten_override.py b/oink/tests/test_aten_override.py new file mode 100644 index 0000000..9b8fc5e --- /dev/null +++ 
b/oink/tests/test_aten_override.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Oink's operator overrides. + +Verifies that ``register_all_kernels`` / ``override_all_kernels`` +properly patches Oink's kernels and their backward, and that the +overridden kernels produce numerically correct results. + +Reference values are computed via pure-PyTorch math (float32 accumulation) +to avoid issues with ``call_boxed`` and stale ``SafeKernelFunction`` +references when ``torch._native`` overrides are also active. 
+""" + +from __future__ import annotations + +import types + +import pytest +import torch + +TEST_CUDA = torch.cuda.is_available() + +_SM = 0 +if TEST_CUDA: + _major, _minor = torch.cuda.get_device_capability(0) + _SM = 10 * _major + _minor +SM100_OR_LATER = _SM >= 100 + +requires_cuda = pytest.mark.skipif(not TEST_CUDA, reason="CUDA not available") +requires_sm100 = pytest.mark.skipif(not SM100_OR_LATER, reason="requires SM100+") + + +_OVERRIDE_APPLIED = False +if TEST_CUDA and SM100_OR_LATER: + try: + from kernelagent_oink.aten_override import override_all_kernels + + override_all_kernels() + _OVERRIDE_APPLIED = True + except Exception: + pass + +requires_override = pytest.mark.skipif( + not _OVERRIDE_APPLIED, reason="override not applied" +) + + +SHAPES = [(8, 128), (4, 8, 32), (2, 16, 512), (4, 32, 1024)] +DTYPES = [torch.float16, torch.bfloat16, torch.float32] +EPS = 1e-5 + + +def _atol_for(dtype): + if dtype == torch.bfloat16: + return 1e-1 # bf16 has 8-bit mantissa, larger rounding error + if dtype == torch.float16: + return 1e-2 # fp16 has 11-bit mantissa + return 1e-4 # fp32 + + +@requires_cuda +@requires_sm100 +def test_override_sets_library(): + """The Library object should be non-None after override.""" + from kernelagent_oink.aten_override import _OVERRIDE_LIB + + assert _OVERRIDE_LIB is not None, "override_all_kernels did not create Library" + + +@requires_cuda +@requires_sm100 +def test_custom_ops_registered(): + """torch.ops.oink.rmsnorm should be callable after register().""" + from kernelagent_oink import register + + register(force=True) + assert hasattr(torch.ops, "oink"), "torch.ops.oink namespace missing" + assert hasattr(torch.ops.oink, "rmsnorm"), "torch.ops.oink.rmsnorm missing" + + +def test_oink_availability_checks(monkeypatch: pytest.MonkeyPatch): + """Probe _is_supported with mocked CUDA.""" + from kernelagent_oink.aten_override import _get_device_major, _is_supported + + fake_tensor = types.SimpleNamespace( + is_cuda=True, + 
dtype=torch.float16, + device=torch.device("cuda:0"), + shape=torch.Size([32, 4096]), + ) + + # SM90 (Hopper) → not supported (SM100+ only). + monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(torch.cuda, "get_device_capability", lambda d: (9, 0)) + _get_device_major.cache_clear() + assert _is_supported(fake_tensor) is False + + # SM100 → supported. + monkeypatch.setattr(torch.cuda, "get_device_capability", lambda d: (10, 0)) + _get_device_major.cache_clear() + assert _is_supported(fake_tensor) is True + + # float64 → not supported. + fake_f64 = types.SimpleNamespace( + is_cuda=True, + dtype=torch.float64, + device=torch.device("cuda:0"), + shape=torch.Size([32, 4096]), + ) + assert _is_supported(fake_f64) is False + + # N < 128 → not supported (kernel requires N >= 128). + fake_small_n = types.SimpleNamespace( + is_cuda=True, + dtype=torch.float16, + device=torch.device("cuda:0"), + shape=torch.Size([32, 64]), + ) + assert _is_supported(fake_small_n) is False + + _get_device_major.cache_clear() + + +@requires_cuda +@requires_sm100 +@requires_override +@pytest.mark.parametrize("dtype", DTYPES) +def test_rmsnorm_fwd(dtype): + atol = _atol_for(dtype) + for shape in SHAPES: + normalized_shape = [shape[-1]] + x = torch.randn(*shape, dtype=dtype, device="cuda") + w = torch.randn(*normalized_shape, dtype=dtype, device="cuda") + + # Oink override. + y, rstd = torch.ops.aten._fused_rms_norm(x, normalized_shape, w, EPS) + + # Pure-torch reference (float32 accumulation). 
+ N = shape[-1] + M = x.numel() // N + x_f32 = x.reshape(M, N).float() + rstd_ref = torch.rsqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + EPS) + y_ref = ((x_f32 * rstd_ref) * w.float()).to(dtype).reshape(shape) + + torch.testing.assert_close( + y, y_ref, atol=atol, rtol=0, msg=f"fwd y shape={shape} dtype={dtype}" + ) + + +@requires_cuda +@requires_sm100 +@requires_override +@pytest.mark.parametrize("dtype", DTYPES) +def test_rmsnorm_bwd(dtype): + for shape in SHAPES: + normalized_shape = [shape[-1]] + x = torch.randn(*shape, dtype=dtype, device="cuda", requires_grad=True) + w = torch.randn( + *normalized_shape, dtype=dtype, device="cuda", requires_grad=True + ) + grad_out = torch.randn(*shape, dtype=dtype, device="cuda") + + # Oink override fwd + bwd. + y, _ = torch.ops.aten._fused_rms_norm(x, normalized_shape, w, EPS) + y.backward(grad_out) + + assert x.grad is not None, f"x.grad is None for shape={shape}" + assert w.grad is not None, f"w.grad is None for shape={shape}" + assert x.grad.shape == x.shape + assert w.grad.shape == w.shape + assert torch.isfinite(x.grad).all(), f"x.grad has inf/nan for shape={shape}" + assert torch.isfinite(w.grad).all(), f"w.grad has inf/nan for shape={shape}" + + +@requires_cuda +@requires_sm100 +@requires_override +@pytest.mark.parametrize( + "mask", [[True, True], [True, False], [False, True], [False, False]] +) +def test_backward_output_mask(mask): + """Backward should return None for masked outputs.""" + x = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda") + w = torch.randn(128, dtype=torch.bfloat16, device="cuda") + grad = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda") + + _, rstd = torch.ops.aten._fused_rms_norm(x, [128], w, EPS) + + dx, dw = torch.ops.aten._fused_rms_norm_backward(grad, x, [128], rstd, w, mask) + + if not mask[0]: + assert dx is None, "dx should be None when output_mask[0]=False" + else: + assert dx is not None and dx.shape == x.shape + + if not mask[1]: + assert dw is None, "dw should be 
None when output_mask[1]=False" + else: + assert dw is not None and dw.shape == w.shape + + +@requires_cuda +@requires_sm100 +@requires_override +def test_float64_rmsnorm_falls_back(): + """float64 is not supported by oink — should fall back gracefully.""" + x = torch.randn(4, 32, dtype=torch.float64, device="cuda") + w = torch.randn(32, dtype=torch.float64, device="cuda") + y, rstd = torch.ops.aten._fused_rms_norm(x, [32], w, EPS) + assert y.shape == x.shape + assert y.dtype == torch.float64 + + +@requires_cuda +@requires_sm100 +def test_restore_then_reregister(): + """restore + re-register should work in the same process.""" + from kernelagent_oink import unregister_all_kernels + from kernelagent_oink.aten_override import override_all_kernels + + unregister_all_kernels() + + # After unregister, re-register should succeed. + override_all_kernels() + + from kernelagent_oink.aten_override import _OVERRIDE_LIB + + assert _OVERRIDE_LIB is not None