From cd6814f44226d03599a697f16ad7ac430c40a81c Mon Sep 17 00:00:00 2001 From: Jiannan Wang Date: Wed, 8 Apr 2026 22:19:50 -0700 Subject: [PATCH 1/6] Add register_all_kernels API and aten override for all oink kernels Adds kernelagent_oink.register_all_kernels() which registers oink custom ops (torch.ops.oink.*) and overrides 6 aten ops at the CUDA dispatch key so that standard PyTorch APIs (F.rms_norm, F.layer_norm, F.softmax) transparently use oink's SM100 CuTeDSL kernels: - aten::_fused_rms_norm / _fused_rms_norm_backward - aten::native_layer_norm / native_layer_norm_backward - aten::_softmax / _softmax_backward_data Original CUDA kernels are captured before override for fallback on unsupported inputs (wrong dtype, SM < 100). Includes tests comparing override output against captured original kernels. --- oink/src/kernelagent_oink/__init__.py | 51 ++- oink/src/kernelagent_oink/aten_override.py | 432 ++++++++++++++++++ oink/tests/test_aten_override.py | 486 +++++++++++++++++++++ 3 files changed, 968 insertions(+), 1 deletion(-) create mode 100644 oink/src/kernelagent_oink/aten_override.py create mode 100644 oink/tests/test_aten_override.py diff --git a/oink/src/kernelagent_oink/__init__.py b/oink/src/kernelagent_oink/__init__.py index 14c1732..ee8b18f 100644 --- a/oink/src/kernelagent_oink/__init__.py +++ b/oink/src/kernelagent_oink/__init__.py @@ -123,4 +123,53 @@ def register(*, force: bool = False) -> None: _OPS_REGISTERED = True -__all__ = ["register"] +_ALL_KERNELS_REGISTERED = False + + +def register_all_kernels(*, force: bool = False) -> None: + """Register Oink custom ops *and* override PyTorch's native aten operators. + + This is the main entry point for redirecting standard PyTorch calls + (``F.rms_norm``, ``F.layer_norm``, ``F.softmax``, etc.) to Oink's SM100 + CuTeDSL kernels. It performs two steps: + + 1. Calls :func:`register` to define ``torch.ops.oink.rmsnorm`` and + ``torch.ops.oink.fused_add_rms_norm`` custom ops. + 2. Patches the following aten ops at the CUDA dispatch key so that + PyTorch's built-in operators transparently use the Oink kernels: + + - ``aten::_fused_rms_norm`` / ``aten::_fused_rms_norm_backward`` + - ``aten::native_layer_norm`` / ``aten::native_layer_norm_backward`` + - ``aten::_softmax`` / ``aten::_softmax_backward_data`` + + Unsupported inputs (wrong dtype, SM < 100) fall back to PyTorch's + original CUDA kernels automatically. + + Args: + force: If *True*, register regardless of the + ``VLLM_USE_OINK_RMSNORM`` environment variable. + """ + global _ALL_KERNELS_REGISTERED + if _ALL_KERNELS_REGISTERED: + return + + # Step 1: register torch.ops.oink.* custom ops. + register(force=force) + + if not _OPS_REGISTERED: + # register() decided to bail (missing deps, no CUDA, env gate, etc.). + return + + # Step 2: override aten ops on CUDA. + try: + from .aten_override import override_all_aten_kernels + + override_all_aten_kernels() + except Exception as e: # pragma: no cover + logger.exception("Oink: failed to override aten ops: %s", e) + return + + _ALL_KERNELS_REGISTERED = True + + +__all__ = ["register", "register_all_kernels"] diff --git a/oink/src/kernelagent_oink/aten_override.py b/oink/src/kernelagent_oink/aten_override.py new file mode 100644 index 0000000..08f103a --- /dev/null +++ b/oink/src/kernelagent_oink/aten_override.py @@ -0,0 +1,432 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Override PyTorch aten operators with Oink's Blackwell CuTeDSL kernels.
+
+Patches the following aten ops at the CUDA dispatch key:
+
+- ``aten::_fused_rms_norm`` → :func:`rmsnorm_forward`
+- ``aten::_fused_rms_norm_backward`` → :func:`rmsnorm_backward`
+- ``aten::native_layer_norm`` → :func:`layernorm`
+- ``aten::native_layer_norm_backward`` → :func:`layernorm_backward`
+- ``aten::_softmax`` → :func:`softmax_forward`
+- ``aten::_softmax_backward_data`` → :func:`softmax_backward`
+
+All standard PyTorch APIs (``F.rms_norm``, ``F.layer_norm``, ``F.softmax``,
+etc.) transparently route through the Oink kernels on SM100+ CUDA devices
+after calling :func:`override_all_aten_kernels`.
+
+Each override captures the original CUDA kernel via
+``torch.library.get_kernel`` *before* patching, so unsupported inputs
+(wrong dtype, older GPU) fall back to PyTorch's native implementation.
+
+Usage::
+
+    from kernelagent_oink.aten_override import (
+        override_all_aten_kernels,
+        restore_all_aten_kernels,
+    )
+
+    override_all_aten_kernels()
+    y = torch.nn.functional.rms_norm(x, [N], weight, eps)  # uses Oink
+
+    restore_all_aten_kernels()  # restores PyTorch
+"""
+
+from __future__ import annotations
+
+import functools
+import importlib
+import logging
+import math
+import threading
+from typing import List, Optional, Tuple
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+# ---------------------------------------------------------------------------
+# Lazy kernel module imports
+# ---------------------------------------------------------------------------
+
+_MOD_CACHE: dict[str, object] = {}
+_MOD_LOCK = threading.Lock()
+
+
+def _get_mod(name: str):
+    """Thread-safe lazy import of ``kernelagent_oink.blackwell.<name>``."""
+    cached = _MOD_CACHE.get(name)
+    if cached is not None:
+        return cached
+    with _MOD_LOCK:
+        if name not in _MOD_CACHE:
+            _MOD_CACHE[name] = importlib.import_module(
+                f"kernelagent_oink.blackwell.{name}"
+            )
+        return _MOD_CACHE[name]
+
+
+# ---------------------------------------------------------------------------
+# Device capability helpers
+# ---------------------------------------------------------------------------
+
+
+@functools.cache
+def _get_device_sm(device: torch.device) -> int:
+    major, minor = torch.cuda.get_device_capability(device)
+    return 10 * major + minor
+
+
+_SUPPORTED_DTYPES = (torch.float16, torch.bfloat16, torch.float32)
+
+
+def _is_supported(t: torch.Tensor) -> bool:
+    """True when Oink's SM100 kernel can handle this tensor."""
+    return (
+        t.is_cuda
+        and t.dtype in _SUPPORTED_DTYPES
+        and _get_device_sm(t.device) >= 100
+    )
+
+
+# ---------------------------------------------------------------------------
+# Fallback kernel capture
+# ---------------------------------------------------------------------------
+
+_fallbacks: dict[str, object] = {}
+
+
+def _capture_fallback(op_name: str, dispatch_key: str = "CUDA") -> None:
+    """Snapshot the current CUDA kernel for ``aten::<op_name>`` before we
+    overwrite it. 
Must be called *before* ``lib.impl``.""" + if op_name in _fallbacks: + return + try: + _fallbacks[op_name] = torch.library.get_kernel( + f"aten::{op_name}", dispatch_key + ) + except Exception: + _fallbacks[op_name] = None + + +def _call_fallback(op_name: str, *args): + fb = _fallbacks.get(op_name) + if fb is not None: + return fb(*args) + raise RuntimeError( + f"Oink: no fallback captured for aten::{op_name} and input is unsupported" + ) + + +# --------------------------------------------------------------------------- +# Reshape helpers +# --------------------------------------------------------------------------- + + +def _reshape_2d(t: torch.Tensor, M: int, N: int) -> torch.Tensor: + if t.ndim == 2 and t.shape == (M, N) and t.is_contiguous(): + return t + return t.reshape(M, N).contiguous() + + +def _flatten_1d(t: torch.Tensor, M: int) -> torch.Tensor: + if t.ndim == 1 and t.shape[0] == M: + return t + if t.is_contiguous() and t.numel() == M: + return t.detach().view(M) + return t.reshape(M).contiguous() + + +def _stat_shape(input_shape, normalized_shape_len: int) -> list[int]: + """Shape for rstd / mean: ``[*batch_dims, 1, 1, ...]`` with + ``normalized_shape_len`` trailing ones.""" + return list(input_shape[:-normalized_shape_len]) + [1] * normalized_shape_len + + +# ========================================================================= +# RMSNorm +# ========================================================================= + + +def _oink_fused_rms_norm( + input: torch.Tensor, + normalized_shape: List[int], + weight: Optional[torch.Tensor], + eps: Optional[float], +) -> Tuple[torch.Tensor, torch.Tensor]: + if not _is_supported(input): + return _call_fallback( + "_fused_rms_norm", input, normalized_shape, weight, eps + ) + + if eps is None: + eps = 1e-6 + + input_shape = input.shape + N = math.prod(normalized_shape) + M = input.numel() // N + + x = _reshape_2d(input, M, N) + + mod = _get_mod("rmsnorm") + y, rstd, _ = mod.rmsnorm_forward( + x, weight=weight, bias=None, residual=None, eps=eps, store_rstd=True, + ) + + y = y.reshape(input_shape) + rstd = rstd.view(_stat_shape(input_shape, len(normalized_shape))) + return y, rstd + + +def _oink_fused_rms_norm_backward( + grad_out: torch.Tensor, + input: torch.Tensor, + normalized_shape: List[int], + rstd: torch.Tensor, + weight: Optional[torch.Tensor], + output_mask: List[bool], +) -> Tuple[torch.Tensor, torch.Tensor]: + if not _is_supported(input): + return _call_fallback( + "_fused_rms_norm_backward", + grad_out, input, normalized_shape, rstd, weight, output_mask, + ) + + N = math.prod(normalized_shape) + M = input.numel() // N + + x = _reshape_2d(input, M, N) + dout = _reshape_2d(grad_out, M, N) + rstd_flat = _flatten_1d(rstd, M) + + mod = _get_mod("rmsnorm") + dx, dw, _dbias, _dres = mod.rmsnorm_backward( + x, weight, dout, rstd_flat, + dresidual_out=None, has_bias=False, has_residual=False, + ) + + dx = dx.reshape(input.shape) + + if not output_mask[0]: + dx = torch.zeros_like(input) + if not output_mask[1] or dw is None: + dw = torch.zeros_like(weight) if weight is not None else torch.empty(0) + + return dx, dw + + +# ========================================================================= +# LayerNorm +# ========================================================================= + + +def _oink_native_layer_norm( + input: torch.Tensor, + normalized_shape: List[int], + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + eps: float, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if not _is_supported(input): + 
return _call_fallback( + "native_layer_norm", input, normalized_shape, weight, bias, eps + ) + + input_shape = input.shape + N = math.prod(normalized_shape) + M = input.numel() // N + + x = _reshape_2d(input, M, N) + + mod = _get_mod("layernorm") + out, rstd, mean = mod.layernorm( + x, weight, bias=bias, eps=eps, return_rstd=True, return_mean=True, + ) + + out = out.reshape(input_shape) + stat_sh = _stat_shape(input_shape, len(normalized_shape)) + mean = mean.view(stat_sh) + rstd = rstd.view(stat_sh) + return out, mean, rstd + + +def _oink_native_layer_norm_backward( + grad_out: torch.Tensor, + input: torch.Tensor, + normalized_shape: List[int], + mean: torch.Tensor, + rstd: torch.Tensor, + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + output_mask: List[bool], +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if not _is_supported(input): + return _call_fallback( + "native_layer_norm_backward", + grad_out, input, normalized_shape, mean, rstd, weight, bias, + output_mask, + ) + + N = math.prod(normalized_shape) + M = input.numel() // N + + x = _reshape_2d(input, M, N) + dout = _reshape_2d(grad_out, M, N) + mean_flat = _flatten_1d(mean, M) + rstd_flat = _flatten_1d(rstd, M) + + mod = _get_mod("layernorm") + dx, dw, db = mod.layernorm_backward( + dout, x, weight, rstd_flat, mean_flat, bias=bias, + ) + + dx = dx.reshape(input.shape) if dx is not None else torch.zeros_like(input) + + if not output_mask[0]: + dx = torch.zeros_like(input) + if not output_mask[1] or dw is None: + dw = torch.zeros_like(weight) if weight is not None else torch.empty(0) + if not output_mask[2] or db is None: + db = torch.zeros_like(bias) if bias is not None else torch.empty(0) + + return dx, dw, db + + +# ========================================================================= +# Softmax +# ========================================================================= + + +def _oink_softmax( + self: torch.Tensor, + dim: int, + half_to_float: bool, +) -> torch.Tensor: + # Oink's softmax only handles the last dimension on 2D inputs. + # Fall back for other dims or when half_to_float is requested. 
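+    # Example (an illustration, not an exhaustive spec): a (2, 16, 512) input
+    # with dim=-1 normalizes to actual_dim = 2 == ndim - 1 below, so it takes
+    # the Oink path; dim=0 or dim=1 routes to the captured native kernel.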
+ ndim = self.ndim + actual_dim = dim if dim >= 0 else dim + ndim + + if not _is_supported(self) or actual_dim != ndim - 1 or half_to_float: + return _call_fallback("_softmax", self, dim, half_to_float) + + input_shape = self.shape + N = input_shape[-1] + M = self.numel() // N + + x = _reshape_2d(self, M, N) + + mod = _get_mod("softmax") + y = mod.softmax_forward(x) + + return y.reshape(input_shape) + + +def _oink_softmax_backward( + grad_output: torch.Tensor, + output: torch.Tensor, + dim: int, + input_dtype: torch.dtype, +) -> torch.Tensor: + ndim = output.ndim + actual_dim = dim if dim >= 0 else dim + ndim + + if ( + not _is_supported(output) + or actual_dim != ndim - 1 + or input_dtype != output.dtype # half_to_float case + ): + return _call_fallback( + "_softmax_backward_data", grad_output, output, dim, input_dtype + ) + + input_shape = output.shape + N = input_shape[-1] + M = output.numel() // N + + dy = _reshape_2d(grad_output, M, N) + y = _reshape_2d(output, M, N) + + mod = _get_mod("softmax") + dx = mod.softmax_backward(dy, y) + + return dx.reshape(input_shape) + + +# ========================================================================= +# Registration +# ========================================================================= + +_ATEN_LIB: torch.library.Library | None = None + +# Mapping: (aten_op_name, impl_function) +_OVERRIDES = [ + ("_fused_rms_norm", _oink_fused_rms_norm), + ("_fused_rms_norm_backward", _oink_fused_rms_norm_backward), + ("native_layer_norm", _oink_native_layer_norm), + ("native_layer_norm_backward", _oink_native_layer_norm_backward), + ("_softmax", _oink_softmax), + ("_softmax_backward_data", _oink_softmax_backward), +] + + +def override_all_aten_kernels() -> None: + """Patch all supported aten ops on the CUDA dispatch key to use Oink's + SM100 CuTeDSL kernels. + + Idempotent — safe to call multiple times. Captures the original CUDA + kernels before overriding so that unsupported inputs (wrong dtype, older + GPU) fall back transparently. + """ + global _ATEN_LIB + if _ATEN_LIB is not None: + return + + # Capture original kernels *before* we overwrite them. + for op_name, _ in _OVERRIDES: + _capture_fallback(op_name) + + lib = torch.library.Library("aten", "IMPL") + registered = [] + + for op_name, impl_fn in _OVERRIDES: + try: + lib.impl(op_name, impl_fn, "CUDA") + registered.append(op_name) + except Exception as e: + logger.warning("Oink: could not override aten::%s: %s", op_name, e) + + _ATEN_LIB = lib + logger.info("Oink: overrode %d aten ops on CUDA: %s", len(registered), registered) + + +def restore_all_aten_kernels() -> None: + """Remove all Oink overrides and restore PyTorch's native CUDA kernels.""" + global _ATEN_LIB + if _ATEN_LIB is None: + return + _ATEN_LIB = None + logger.info("Oink: restored all aten ops to PyTorch defaults") + + +# Keep the old single-op API for backward compatibility. +override_aten_rmsnorm = override_all_aten_kernels +restore_aten_rmsnorm = restore_all_aten_kernels + + +__all__ = [ + "override_all_aten_kernels", + "restore_all_aten_kernels", + "override_aten_rmsnorm", + "restore_aten_rmsnorm", +] diff --git a/oink/tests/test_aten_override.py b/oink/tests/test_aten_override.py new file mode 100644 index 0000000..b7aac84 --- /dev/null +++ b/oink/tests/test_aten_override.py @@ -0,0 +1,486 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Oink's aten operator overrides. + +Verifies that ``register_all_kernels`` / ``override_all_aten_kernels`` +properly patches PyTorch's aten ops and that the overridden kernels produce +numerically correct results compared to PyTorch's native CUDA kernels. + +The correctness tests capture the original CUDA kernel *before* the override +is applied, then compare the override's output against it. This mirrors the +approach used in PyTorch's own quack override tests. +""" + +from __future__ import annotations + +import math +import unittest + +import torch + +# --------------------------------------------------------------------------- +# Skip helpers +# --------------------------------------------------------------------------- + +TEST_CUDA = torch.cuda.is_available() + +_SM = 0 +if TEST_CUDA: + _major, _minor = torch.cuda.get_device_capability(0) + _SM = 10 * _major + _minor + +SM100_OR_LATER = _SM >= 100 + +# --------------------------------------------------------------------------- +# Capture original CUDA kernels *before* any override is applied. +# --------------------------------------------------------------------------- + +_orig_fused_rms_norm = None +_orig_fused_rms_norm_bwd = None +_orig_native_layer_norm = None +_orig_native_layer_norm_bwd = None +_orig_softmax = None +_orig_softmax_bwd = None + +if TEST_CUDA: + try: + _orig_fused_rms_norm = torch.library.get_kernel( + "aten::_fused_rms_norm", "CUDA" + ) + _orig_fused_rms_norm_bwd = torch.library.get_kernel( + "aten::_fused_rms_norm_backward", "CUDA" + ) + _orig_native_layer_norm = torch.library.get_kernel( + "aten::native_layer_norm", "CUDA" + ) + _orig_native_layer_norm_bwd = torch.library.get_kernel( + "aten::native_layer_norm_backward", "CUDA" + ) + _orig_softmax = torch.library.get_kernel("aten::_softmax", "CUDA") + _orig_softmax_bwd = torch.library.get_kernel( + "aten::_softmax_backward_data", "CUDA" + ) + except Exception: + pass + +# --------------------------------------------------------------------------- +# Apply the override (module-level so it happens once). 
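+# (Ordering matters: the get_kernel calls above must run before
+# override_all_aten_kernels below, otherwise the captured "originals" would
+# already be the Oink implementations.)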
+# --------------------------------------------------------------------------- + +_OVERRIDE_APPLIED = False + +if TEST_CUDA and SM100_OR_LATER: + try: + from kernelagent_oink.aten_override import ( + _ATEN_LIB, + _fallbacks, + override_all_aten_kernels, + ) + + override_all_aten_kernels() + _OVERRIDE_APPLIED = True + except Exception: + pass + + +# ========================================================================= +# Registration tests +# ========================================================================= + + +class TestAtenOverrideRegistration(unittest.TestCase): + """Verify that override_all_aten_kernels sets up state correctly.""" + + @unittest.skipIf(not TEST_CUDA, "CUDA not available") + @unittest.skipIf(not SM100_OR_LATER, "requires SM100+") + def test_override_sets_library(self): + """The aten Library object should be non-None after override.""" + from kernelagent_oink.aten_override import _ATEN_LIB + + self.assertIsNotNone( + _ATEN_LIB, "override_all_aten_kernels did not create the Library" + ) + + @unittest.skipIf(not TEST_CUDA, "CUDA not available") + @unittest.skipIf(not SM100_OR_LATER, "requires SM100+") + def test_all_fallbacks_captured(self): + """All 6 fallback kernels should have been captured.""" + from kernelagent_oink.aten_override import _fallbacks + + expected_ops = [ + "_fused_rms_norm", + "_fused_rms_norm_backward", + "native_layer_norm", + "native_layer_norm_backward", + "_softmax", + "_softmax_backward_data", + ] + for op in expected_ops: + self.assertIn(op, _fallbacks, f"fallback not captured for {op}") + self.assertIsNotNone(_fallbacks[op], f"fallback is None for {op}") + + @unittest.skipIf(not TEST_CUDA, "CUDA not available") + @unittest.skipIf(not SM100_OR_LATER, "requires SM100+") + def test_custom_ops_registered(self): + """torch.ops.oink.rmsnorm should be callable after registration.""" + from kernelagent_oink import register_all_kernels + + register_all_kernels(force=True) + self.assertTrue( + hasattr(torch.ops, "oink"), "torch.ops.oink namespace missing" + ) + self.assertTrue( + hasattr(torch.ops.oink, "rmsnorm"), "torch.ops.oink.rmsnorm missing" + ) + + +# ========================================================================= +# Correctness tests — RMSNorm +# ========================================================================= + +SHAPES = [(8, 128), (4, 8, 32), (2, 16, 512), (4, 32, 1024)] +DTYPES = [torch.float16, torch.bfloat16, torch.float32] +EPS = 1e-5 + + +def _atol_for(dtype): + if dtype in (torch.float16, torch.bfloat16): + return 1e-1 + return 1e-5 + + +@unittest.skipIf(not TEST_CUDA, "CUDA not available") +@unittest.skipIf(not SM100_OR_LATER, "requires SM100+") +@unittest.skipIf(not _OVERRIDE_APPLIED, "override not applied") +class TestRMSNormOverride(unittest.TestCase): + """Compare oink RMSNorm override against the captured ATen fallback.""" + + def _run_fwd(self, shape, dtype): + normalized_shape = [shape[-1]] + x = torch.randn(*shape, dtype=dtype, device="cuda") + w = torch.randn(*normalized_shape, dtype=dtype, device="cuda") + + # Oink (through overridden aten op) + y, rstd = torch.ops.aten._fused_rms_norm(x, normalized_shape, w, EPS) + + # Reference (captured original kernel) + y_ref, rstd_ref = _orig_fused_rms_norm(x, normalized_shape, w, EPS) + + atol = _atol_for(dtype) + torch.testing.assert_close( + y, y_ref, atol=atol, rtol=0, msg=f"fwd y shape={shape} dtype={dtype}" + ) + torch.testing.assert_close( + rstd, + rstd_ref, + atol=1e-5, + rtol=0, + msg=f"fwd rstd shape={shape} dtype={dtype}", + ) + + def 
test_fwd_fp16(self): + for shape in SHAPES: + self._run_fwd(shape, torch.float16) + + def test_fwd_bf16(self): + for shape in SHAPES: + self._run_fwd(shape, torch.bfloat16) + + def test_fwd_fp32(self): + for shape in SHAPES: + self._run_fwd(shape, torch.float32) + + def _run_bwd(self, shape, dtype): + normalized_shape = [shape[-1]] + x = torch.randn(*shape, dtype=dtype, device="cuda") + w = torch.randn(*normalized_shape, dtype=dtype, device="cuda") + grad_out = torch.randn(*shape, dtype=dtype, device="cuda") + + # Oink + x1 = x.detach().requires_grad_(True) + w1 = w.detach().requires_grad_(True) + y1, _ = torch.ops.aten._fused_rms_norm(x1, normalized_shape, w1, EPS) + y1.backward(grad_out) + + # Reference + x2 = x.detach().requires_grad_(True) + w2 = w.detach().requires_grad_(True) + y2, _ = _orig_fused_rms_norm(x2, normalized_shape, w2, EPS) + y2.backward(grad_out) + + atol = 3e-1 if dtype == torch.bfloat16 else _atol_for(dtype) + torch.testing.assert_close( + x1.grad, + x2.grad, + atol=atol, + rtol=0, + msg=f"bwd x_grad shape={shape} dtype={dtype}", + ) + torch.testing.assert_close( + w1.grad, + w2.grad, + atol=atol, + rtol=0, + msg=f"bwd w_grad shape={shape} dtype={dtype}", + ) + + def test_bwd_fp16(self): + for shape in SHAPES: + self._run_bwd(shape, torch.float16) + + def test_bwd_bf16(self): + for shape in SHAPES: + self._run_bwd(shape, torch.bfloat16) + + def test_bwd_fp32(self): + for shape in SHAPES: + self._run_bwd(shape, torch.float32) + + +# ========================================================================= +# Correctness tests — LayerNorm +# ========================================================================= + + +@unittest.skipIf(not TEST_CUDA, "CUDA not available") +@unittest.skipIf(not SM100_OR_LATER, "requires SM100+") +@unittest.skipIf(not _OVERRIDE_APPLIED, "override not applied") +class TestLayerNormOverride(unittest.TestCase): + """Compare oink LayerNorm override against the captured ATen fallback.""" + + def _run_fwd(self, shape, dtype): + normalized_shape = [shape[-1]] + x = torch.randn(*shape, dtype=dtype, device="cuda") + w = torch.randn(*normalized_shape, dtype=dtype, device="cuda") + b = torch.randn(*normalized_shape, dtype=dtype, device="cuda") + + # Oink + out, mean, rstd = torch.ops.aten.native_layer_norm( + x, normalized_shape, w, b, EPS + ) + + # Reference + out_ref, mean_ref, rstd_ref = _orig_native_layer_norm( + x, normalized_shape, w, b, EPS + ) + + atol = _atol_for(dtype) + torch.testing.assert_close( + out, out_ref, atol=atol, rtol=0, msg=f"fwd shape={shape} dtype={dtype}" + ) + torch.testing.assert_close( + mean, + mean_ref, + atol=1e-5, + rtol=0, + msg=f"fwd mean shape={shape} dtype={dtype}", + ) + torch.testing.assert_close( + rstd, + rstd_ref, + atol=1e-5, + rtol=0, + msg=f"fwd rstd shape={shape} dtype={dtype}", + ) + + def test_fwd_fp16(self): + for shape in SHAPES: + self._run_fwd(shape, torch.float16) + + def test_fwd_bf16(self): + for shape in SHAPES: + self._run_fwd(shape, torch.bfloat16) + + def test_fwd_fp32(self): + for shape in SHAPES: + self._run_fwd(shape, torch.float32) + + def _run_bwd(self, shape, dtype): + normalized_shape = [shape[-1]] + x = torch.randn(*shape, dtype=dtype, device="cuda") + w = torch.randn(*normalized_shape, dtype=dtype, device="cuda") + b = torch.randn(*normalized_shape, dtype=dtype, device="cuda") + grad_out = torch.randn(*shape, dtype=dtype, device="cuda") + + # Oink + x1 = x.detach().requires_grad_(True) + w1 = w.detach().requires_grad_(True) + b1 = b.detach().requires_grad_(True) + out1, _, _ = 
torch.ops.aten.native_layer_norm( + x1, normalized_shape, w1, b1, EPS + ) + out1.backward(grad_out) + + # Reference + x2 = x.detach().requires_grad_(True) + w2 = w.detach().requires_grad_(True) + b2 = b.detach().requires_grad_(True) + out2, _, _ = _orig_native_layer_norm(x2, normalized_shape, w2, b2, EPS) + out2.backward(grad_out) + + atol = 3e-1 if dtype == torch.bfloat16 else _atol_for(dtype) + torch.testing.assert_close( + x1.grad, + x2.grad, + atol=atol, + rtol=0, + msg=f"bwd x_grad shape={shape} dtype={dtype}", + ) + torch.testing.assert_close( + w1.grad, + w2.grad, + atol=atol, + rtol=0, + msg=f"bwd w_grad shape={shape} dtype={dtype}", + ) + torch.testing.assert_close( + b1.grad, + b2.grad, + atol=atol, + rtol=0, + msg=f"bwd b_grad shape={shape} dtype={dtype}", + ) + + def test_bwd_fp16(self): + for shape in SHAPES: + self._run_bwd(shape, torch.float16) + + def test_bwd_bf16(self): + for shape in SHAPES: + self._run_bwd(shape, torch.bfloat16) + + def test_bwd_fp32(self): + for shape in SHAPES: + self._run_bwd(shape, torch.float32) + + +# ========================================================================= +# Correctness tests — Softmax +# ========================================================================= + + +@unittest.skipIf(not TEST_CUDA, "CUDA not available") +@unittest.skipIf(not SM100_OR_LATER, "requires SM100+") +@unittest.skipIf(not _OVERRIDE_APPLIED, "override not applied") +class TestSoftmaxOverride(unittest.TestCase): + """Compare oink Softmax override against the captured ATen fallback.""" + + def _run_fwd(self, shape, dtype): + x = torch.randn(*shape, dtype=dtype, device="cuda") + + # Oink (dim=-1, half_to_float=False) + y = torch.ops.aten._softmax(x, -1, False) + + # Reference + y_ref = _orig_softmax(x, -1, False) + + atol = _atol_for(dtype) + torch.testing.assert_close( + y, y_ref, atol=atol, rtol=0, msg=f"fwd shape={shape} dtype={dtype}" + ) + + def test_fwd_fp16(self): + for shape in SHAPES: + self._run_fwd(shape, torch.float16) + + def test_fwd_bf16(self): + for shape in SHAPES: + self._run_fwd(shape, torch.bfloat16) + + def test_fwd_fp32(self): + for shape in SHAPES: + self._run_fwd(shape, torch.float32) + + def _run_bwd(self, shape, dtype): + x = torch.randn(*shape, dtype=dtype, device="cuda") + grad_out = torch.randn(*shape, dtype=dtype, device="cuda") + + # Oink + x1 = x.detach().requires_grad_(True) + y1 = torch.softmax(x1, dim=-1) + y1.backward(grad_out) + + # Reference (manual softmax + bwd to avoid the override) + x2 = x.detach().requires_grad_(True) + y2 = _orig_softmax(x2, -1, False) + dx_ref = _orig_softmax_bwd(grad_out, y2, -1, dtype) + + atol = 3e-1 if dtype == torch.bfloat16 else _atol_for(dtype) + torch.testing.assert_close( + x1.grad, + dx_ref, + atol=atol, + rtol=0, + msg=f"bwd shape={shape} dtype={dtype}", + ) + + def test_bwd_fp16(self): + for shape in SHAPES: + self._run_bwd(shape, torch.float16) + + def test_bwd_bf16(self): + for shape in SHAPES: + self._run_bwd(shape, torch.bfloat16) + + def test_bwd_fp32(self): + for shape in SHAPES: + self._run_bwd(shape, torch.float32) + + +# ========================================================================= +# Fallback tests +# ========================================================================= + + +@unittest.skipIf(not TEST_CUDA, "CUDA not available") +@unittest.skipIf(not SM100_OR_LATER, "requires SM100+") +@unittest.skipIf(not _OVERRIDE_APPLIED, "override not applied") +class TestFallback(unittest.TestCase): + """Verify that unsupported inputs fall back to the native CUDA 
kernel.""" + + def test_float64_rmsnorm_falls_back(self): + """float64 is not supported by oink — should fall back gracefully.""" + x = torch.randn(4, 32, dtype=torch.float64, device="cuda") + w = torch.randn(32, dtype=torch.float64, device="cuda") + y, rstd = torch.ops.aten._fused_rms_norm(x, [32], w, EPS) + self.assertEqual(y.shape, x.shape) + self.assertEqual(y.dtype, torch.float64) + + def test_float64_layernorm_falls_back(self): + x = torch.randn(4, 32, dtype=torch.float64, device="cuda") + w = torch.randn(32, dtype=torch.float64, device="cuda") + b = torch.randn(32, dtype=torch.float64, device="cuda") + out, mean, rstd = torch.ops.aten.native_layer_norm(x, [32], w, b, EPS) + self.assertEqual(out.shape, x.shape) + self.assertEqual(out.dtype, torch.float64) + + def test_float64_softmax_falls_back(self): + x = torch.randn(4, 32, dtype=torch.float64, device="cuda") + y = torch.ops.aten._softmax(x, -1, False) + self.assertEqual(y.shape, x.shape) + self.assertEqual(y.dtype, torch.float64) + + def test_non_last_dim_softmax_falls_back(self): + """Softmax on dim=0 should fall back (oink only handles last dim).""" + x = torch.randn(4, 32, dtype=torch.float16, device="cuda") + y = torch.ops.aten._softmax(x, 0, False) + self.assertEqual(y.shape, x.shape) + # Verify correctness: softmax on dim=0 + y_ref = _orig_softmax(x, 0, False) + torch.testing.assert_close(y, y_ref, atol=1e-3, rtol=0) + + +if __name__ == "__main__": + unittest.main() From b18524872ccb66e1f17b0f9604c9ced0cb932ff9 Mon Sep 17 00:00:00 2001 From: Jiannan Wang Date: Mon, 13 Apr 2026 10:46:33 -0700 Subject: [PATCH 2/6] Add vLLM-style stride guards and switch tests to pytest - Add _can_view_as_2d() and _is_oink_stride_compatible_2d() stride checks matching vLLM's layernorm.py pattern. Overrides now fall back to the native CUDA kernel when the tensor layout doesn't meet Oink's pointer-path constraints (stride(1)==1, 256-bit alignment). - Switch tests from unittest to pytest style (plain functions, @pytest.mark.parametrize, monkeypatch for availability checks). - Add stride guard unit tests: _can_view_as_2d, stride compatibility, and mocked _is_supported availability probes. --- oink/src/kernelagent_oink/aten_override.py | 71 ++- oink/tests/test_aten_override.py | 476 ++++++++++----------- 2 files changed, 296 insertions(+), 251 deletions(-) diff --git a/oink/src/kernelagent_oink/aten_override.py b/oink/src/kernelagent_oink/aten_override.py index 08f103a..c2b3ecd 100644 --- a/oink/src/kernelagent_oink/aten_override.py +++ b/oink/src/kernelagent_oink/aten_override.py @@ -129,16 +129,69 @@ def _call_fallback(op_name: str, *args): # --------------------------------------------------------------------------- -# Reshape helpers +# Reshape / stride helpers # --------------------------------------------------------------------------- +def _can_view_as_2d(x: torch.Tensor) -> bool: + """Return True if ``x.view(-1, x.shape[-1])`` is viewable (no copy). + + For a view(-1, N) to be valid, all leading dims must be contiguous with + respect to each other (size-1 dims are ignored). + """ + if x.dim() < 2: + return False + if x.dim() == 2: + return True + for dim in range(x.dim() - 1): + if x.size(dim + 1) != 1 and x.stride(dim) != x.stride(dim + 1) * x.size( + dim + 1 + ): + return False + return True + + +def _is_oink_stride_compatible_2d(x_2d: torch.Tensor) -> bool: + """Return True if *x_2d* meets Oink's pointer-path stride constraints. 
+
+    Requires stride(1) == 1 (row-major last dim) and stride(0) divisible by
+    the vectorization granularity (256 bits).
+    """
+    if x_2d.dim() != 2:
+        return False
+    if x_2d.stride(1) != 1:
+        return False
+    if x_2d.dtype in (torch.float16, torch.bfloat16):
+        divby = 16  # 256 bits / 16 bits = 16 elements
+    elif x_2d.dtype == torch.float32:
+        divby = 8  # 256 bits / 32 bits = 8 elements
+    else:
+        return False
+    return (x_2d.stride(0) % divby) == 0
+
+
 def _reshape_2d(t: torch.Tensor, M: int, N: int) -> torch.Tensor:
     if t.ndim == 2 and t.shape == (M, N) and t.is_contiguous():
         return t
     return t.reshape(M, N).contiguous()
 
 
+def _reshape_2d_checked(t: torch.Tensor, M: int, N: int) -> torch.Tensor | None:
+    """Reshape to 2D and return None if the result doesn't meet Oink's stride
+    constraints. Callers should fall back to the native kernel on None."""
+    if not _can_view_as_2d(t):
+        return None
+    x_2d = t.view(-1, N) if t.dim() > 2 else t
+    if x_2d.shape != (M, N):
+        x_2d = t.reshape(M, N)
+    if not x_2d.is_contiguous() and not _is_oink_stride_compatible_2d(x_2d):
+        # A contiguous copy fixes stride(1) and inter-row gaps; alignment is
+        # still re-checked below, since stride(0) == N may not be a multiple
+        # of the vectorization granularity.
+        x_2d = x_2d.contiguous()
+    if not _is_oink_stride_compatible_2d(x_2d):
+        return None
+    return x_2d
+
+
 def _flatten_1d(t: torch.Tensor, M: int) -> torch.Tensor:
     if t.ndim == 1 and t.shape[0] == M:
         return t
@@ -176,7 +229,11 @@ def _oink_fused_rms_norm(
     N = math.prod(normalized_shape)
     M = input.numel() // N
 
-    x = _reshape_2d(input, M, N)
+    x = _reshape_2d_checked(input, M, N)
+    if x is None:
+        return _call_fallback(
+            "_fused_rms_norm", input, normalized_shape, weight, eps
+        )
 
     mod = _get_mod("rmsnorm")
     y, rstd, _ = mod.rmsnorm_forward(
@@ -246,7 +303,11 @@ def _oink_native_layer_norm(
     N = math.prod(normalized_shape)
     M = input.numel() // N
 
-    x = _reshape_2d(input, M, N)
+    x = _reshape_2d_checked(input, M, N)
+    if x is None:
+        return _call_fallback(
+            "native_layer_norm", input, normalized_shape, weight, bias, eps
+        )
 
     mod = _get_mod("layernorm")
     out, rstd, mean = mod.layernorm(
@@ -324,7 +385,9 @@ def _oink_softmax(
     N = input_shape[-1]
     M = self.numel() // N
 
-    x = _reshape_2d(self, M, N)
+    x = _reshape_2d_checked(self, M, N)
+    if x is None:
+        return _call_fallback("_softmax", self, dim, half_to_float)
 
     mod = _get_mod("softmax")
     y = mod.softmax_forward(x)
diff --git a/oink/tests/test_aten_override.py b/oink/tests/test_aten_override.py
index b7aac84..3efa93c 100644
--- a/oink/tests/test_aten_override.py
+++ b/oink/tests/test_aten_override.py
@@ -19,15 +19,15 @@
 numerically correct results compared to PyTorch's native CUDA kernels.
 
 The correctness tests capture the original CUDA kernel *before* the override
-is applied, then compare the override's output against it. This mirrors the
-approach used in PyTorch's own quack override tests.
+is applied, then compare the override's output against it.
 """
 
 from __future__ import annotations
 
 import math
-import unittest
+import types
 
+import pytest
 import torch
 
 # ---------------------------------------------------------------------------
@@ -43,6 +43,9 @@
 
 SM100_OR_LATER = _SM >= 100
 
+requires_cuda = pytest.mark.skipif(not TEST_CUDA, reason="CUDA not available")
+requires_sm100 = pytest.mark.skipif(not SM100_OR_LATER, reason="requires SM100+")
+
 # ---------------------------------------------------------------------------
 # Capture original CUDA kernels *before* any override is applied. 
# --------------------------------------------------------------------------- @@ -76,245 +79,270 @@ pass # --------------------------------------------------------------------------- -# Apply the override (module-level so it happens once). +# Apply the override (module-level, happens once). # --------------------------------------------------------------------------- _OVERRIDE_APPLIED = False if TEST_CUDA and SM100_OR_LATER: try: - from kernelagent_oink.aten_override import ( - _ATEN_LIB, - _fallbacks, - override_all_aten_kernels, - ) + from kernelagent_oink.aten_override import override_all_aten_kernels override_all_aten_kernels() _OVERRIDE_APPLIED = True except Exception: pass +requires_override = pytest.mark.skipif( + not _OVERRIDE_APPLIED, reason="override not applied" +) + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +SHAPES = [(8, 128), (4, 8, 32), (2, 16, 512), (4, 32, 1024)] +DTYPES = [torch.float16, torch.bfloat16, torch.float32] +EPS = 1e-5 + + +def _atol_for(dtype): + if dtype in (torch.float16, torch.bfloat16): + return 1e-1 + return 1e-5 + # ========================================================================= # Registration tests # ========================================================================= -class TestAtenOverrideRegistration(unittest.TestCase): - """Verify that override_all_aten_kernels sets up state correctly.""" +@requires_cuda +@requires_sm100 +def test_override_sets_library(): + """The aten Library object should be non-None after override.""" + from kernelagent_oink.aten_override import _ATEN_LIB - @unittest.skipIf(not TEST_CUDA, "CUDA not available") - @unittest.skipIf(not SM100_OR_LATER, "requires SM100+") - def test_override_sets_library(self): - """The aten Library object should be non-None after override.""" - from kernelagent_oink.aten_override import _ATEN_LIB + assert _ATEN_LIB is not None, "override_all_aten_kernels did not create Library" - self.assertIsNotNone( - _ATEN_LIB, "override_all_aten_kernels did not create the Library" - ) - @unittest.skipIf(not TEST_CUDA, "CUDA not available") - @unittest.skipIf(not SM100_OR_LATER, "requires SM100+") - def test_all_fallbacks_captured(self): - """All 6 fallback kernels should have been captured.""" - from kernelagent_oink.aten_override import _fallbacks - - expected_ops = [ - "_fused_rms_norm", - "_fused_rms_norm_backward", - "native_layer_norm", - "native_layer_norm_backward", - "_softmax", - "_softmax_backward_data", - ] - for op in expected_ops: - self.assertIn(op, _fallbacks, f"fallback not captured for {op}") - self.assertIsNotNone(_fallbacks[op], f"fallback is None for {op}") - - @unittest.skipIf(not TEST_CUDA, "CUDA not available") - @unittest.skipIf(not SM100_OR_LATER, "requires SM100+") - def test_custom_ops_registered(self): - """torch.ops.oink.rmsnorm should be callable after registration.""" - from kernelagent_oink import register_all_kernels - - register_all_kernels(force=True) - self.assertTrue( - hasattr(torch.ops, "oink"), "torch.ops.oink namespace missing" - ) - self.assertTrue( - hasattr(torch.ops.oink, "rmsnorm"), "torch.ops.oink.rmsnorm missing" - ) +@requires_cuda +@requires_sm100 +def test_all_fallbacks_captured(): + """All 6 fallback kernels should have been captured.""" + from kernelagent_oink.aten_override import _fallbacks + + expected_ops = [ + "_fused_rms_norm", + "_fused_rms_norm_backward", + "native_layer_norm", + "native_layer_norm_backward", + 
"_softmax", + "_softmax_backward_data", + ] + for op in expected_ops: + assert op in _fallbacks, f"fallback not captured for {op}" + assert _fallbacks[op] is not None, f"fallback is None for {op}" + + +@requires_cuda +@requires_sm100 +def test_custom_ops_registered(): + """torch.ops.oink.rmsnorm should be callable after registration.""" + from kernelagent_oink import register_all_kernels + + register_all_kernels(force=True) + assert hasattr(torch.ops, "oink"), "torch.ops.oink namespace missing" + assert hasattr(torch.ops.oink, "rmsnorm"), "torch.ops.oink.rmsnorm missing" # ========================================================================= -# Correctness tests — RMSNorm +# Availability / stride-guard tests (no GPU required for some) # ========================================================================= -SHAPES = [(8, 128), (4, 8, 32), (2, 16, 512), (4, 32, 1024)] -DTYPES = [torch.float16, torch.bfloat16, torch.float32] -EPS = 1e-5 +def test_oink_availability_checks(monkeypatch: pytest.MonkeyPatch): + """Probe is_oink_available_for_device with mocked CUDA.""" + from kernelagent_oink.aten_override import _is_supported -def _atol_for(dtype): - if dtype in (torch.float16, torch.bfloat16): - return 1e-1 - return 1e-5 + # Mock a CUDA tensor with SM90 (below threshold). + fake_tensor = types.SimpleNamespace( + is_cuda=True, dtype=torch.float16, device=torch.device("cuda:0") + ) + + # SM90 → not supported. + monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(torch.cuda, "get_device_capability", lambda d: (9, 0)) + # Clear the cached SM value. + from kernelagent_oink.aten_override import _get_device_sm + + _get_device_sm.cache_clear() + assert _is_supported(fake_tensor) is False + + # SM100 → supported. + monkeypatch.setattr(torch.cuda, "get_device_capability", lambda d: (10, 0)) + _get_device_sm.cache_clear() + assert _is_supported(fake_tensor) is True + + # float64 → not supported even on SM100. + fake_f64 = types.SimpleNamespace( + is_cuda=True, dtype=torch.float64, device=torch.device("cuda:0") + ) + assert _is_supported(fake_f64) is False + + _get_device_sm.cache_clear() + + +def test_can_view_as_2d_stride_guard(): + """Verify _can_view_as_2d correctly identifies non-viewable layouts.""" + from kernelagent_oink.aten_override import _can_view_as_2d + x = torch.zeros((2, 3, 4)) + assert _can_view_as_2d(x) is True -@unittest.skipIf(not TEST_CUDA, "CUDA not available") -@unittest.skipIf(not SM100_OR_LATER, "requires SM100+") -@unittest.skipIf(not _OVERRIDE_APPLIED, "override not applied") -class TestRMSNormOverride(unittest.TestCase): - """Compare oink RMSNorm override against the captured ATen fallback.""" + # Size-1 dims should be ignored by the viewability check. + base = torch.zeros((2, 10, 4)) + x_singleton = base[:, :1, :] + assert _can_view_as_2d(x_singleton) is True - def _run_fwd(self, shape, dtype): + # Middle-dimension stride break: view(-1, hidden) should be invalid. + x2 = x[:, ::2, :] + with pytest.raises(RuntimeError): + x2.view(-1, x2.shape[-1]) + assert _can_view_as_2d(x2) is False + + +def test_is_oink_stride_compatible_2d(): + """Verify vectorization alignment check.""" + from kernelagent_oink.aten_override import _is_oink_stride_compatible_2d + + # Standard contiguous tensor (stride(0)==N, stride(1)==1) → compatible. + x = torch.zeros(4, 128, dtype=torch.float16) + assert _is_oink_stride_compatible_2d(x) is True + + # Padded row: stride(0) % 16 == 0 → compatible. 
+ base = torch.zeros(4, 256, dtype=torch.float16) + x_padded = base[:, :128] # stride(0)=256, stride(1)=1 + assert x_padded.stride(0) == 256 + assert _is_oink_stride_compatible_2d(x_padded) is True + + # 1D tensor → not compatible. + assert _is_oink_stride_compatible_2d(torch.zeros(128)) is False + + # Wrong dtype → not compatible. + assert _is_oink_stride_compatible_2d(torch.zeros(4, 128, dtype=torch.float64)) is False + + +# ========================================================================= +# Correctness tests — RMSNorm +# ========================================================================= + + +@requires_cuda +@requires_sm100 +@requires_override +@pytest.mark.parametrize("dtype", DTYPES) +def test_rmsnorm_fwd(dtype): + atol = _atol_for(dtype) + for shape in SHAPES: normalized_shape = [shape[-1]] x = torch.randn(*shape, dtype=dtype, device="cuda") w = torch.randn(*normalized_shape, dtype=dtype, device="cuda") - # Oink (through overridden aten op) y, rstd = torch.ops.aten._fused_rms_norm(x, normalized_shape, w, EPS) - - # Reference (captured original kernel) y_ref, rstd_ref = _orig_fused_rms_norm(x, normalized_shape, w, EPS) - atol = _atol_for(dtype) torch.testing.assert_close( y, y_ref, atol=atol, rtol=0, msg=f"fwd y shape={shape} dtype={dtype}" ) torch.testing.assert_close( - rstd, - rstd_ref, - atol=1e-5, - rtol=0, + rstd, rstd_ref, atol=1e-5, rtol=0, msg=f"fwd rstd shape={shape} dtype={dtype}", ) - def test_fwd_fp16(self): - for shape in SHAPES: - self._run_fwd(shape, torch.float16) - - def test_fwd_bf16(self): - for shape in SHAPES: - self._run_fwd(shape, torch.bfloat16) - def test_fwd_fp32(self): - for shape in SHAPES: - self._run_fwd(shape, torch.float32) - - def _run_bwd(self, shape, dtype): +@requires_cuda +@requires_sm100 +@requires_override +@pytest.mark.parametrize("dtype", DTYPES) +def test_rmsnorm_bwd(dtype): + atol = 3e-1 if dtype == torch.bfloat16 else _atol_for(dtype) + for shape in SHAPES: normalized_shape = [shape[-1]] x = torch.randn(*shape, dtype=dtype, device="cuda") w = torch.randn(*normalized_shape, dtype=dtype, device="cuda") grad_out = torch.randn(*shape, dtype=dtype, device="cuda") - # Oink x1 = x.detach().requires_grad_(True) w1 = w.detach().requires_grad_(True) y1, _ = torch.ops.aten._fused_rms_norm(x1, normalized_shape, w1, EPS) y1.backward(grad_out) - # Reference x2 = x.detach().requires_grad_(True) w2 = w.detach().requires_grad_(True) y2, _ = _orig_fused_rms_norm(x2, normalized_shape, w2, EPS) y2.backward(grad_out) - atol = 3e-1 if dtype == torch.bfloat16 else _atol_for(dtype) torch.testing.assert_close( - x1.grad, - x2.grad, - atol=atol, - rtol=0, + x1.grad, x2.grad, atol=atol, rtol=0, msg=f"bwd x_grad shape={shape} dtype={dtype}", ) torch.testing.assert_close( - w1.grad, - w2.grad, - atol=atol, - rtol=0, + w1.grad, w2.grad, atol=atol, rtol=0, msg=f"bwd w_grad shape={shape} dtype={dtype}", ) - def test_bwd_fp16(self): - for shape in SHAPES: - self._run_bwd(shape, torch.float16) - - def test_bwd_bf16(self): - for shape in SHAPES: - self._run_bwd(shape, torch.bfloat16) - - def test_bwd_fp32(self): - for shape in SHAPES: - self._run_bwd(shape, torch.float32) - # ========================================================================= # Correctness tests — LayerNorm # ========================================================================= -@unittest.skipIf(not TEST_CUDA, "CUDA not available") -@unittest.skipIf(not SM100_OR_LATER, "requires SM100+") -@unittest.skipIf(not _OVERRIDE_APPLIED, "override not applied") -class 
TestLayerNormOverride(unittest.TestCase): - """Compare oink LayerNorm override against the captured ATen fallback.""" - - def _run_fwd(self, shape, dtype): +@requires_cuda +@requires_sm100 +@requires_override +@pytest.mark.parametrize("dtype", DTYPES) +def test_layernorm_fwd(dtype): + atol = _atol_for(dtype) + for shape in SHAPES: normalized_shape = [shape[-1]] x = torch.randn(*shape, dtype=dtype, device="cuda") w = torch.randn(*normalized_shape, dtype=dtype, device="cuda") b = torch.randn(*normalized_shape, dtype=dtype, device="cuda") - # Oink out, mean, rstd = torch.ops.aten.native_layer_norm( x, normalized_shape, w, b, EPS ) - - # Reference out_ref, mean_ref, rstd_ref = _orig_native_layer_norm( x, normalized_shape, w, b, EPS ) - atol = _atol_for(dtype) torch.testing.assert_close( out, out_ref, atol=atol, rtol=0, msg=f"fwd shape={shape} dtype={dtype}" ) torch.testing.assert_close( - mean, - mean_ref, - atol=1e-5, - rtol=0, + mean, mean_ref, atol=1e-5, rtol=0, msg=f"fwd mean shape={shape} dtype={dtype}", ) torch.testing.assert_close( - rstd, - rstd_ref, - atol=1e-5, - rtol=0, + rstd, rstd_ref, atol=1e-5, rtol=0, msg=f"fwd rstd shape={shape} dtype={dtype}", ) - def test_fwd_fp16(self): - for shape in SHAPES: - self._run_fwd(shape, torch.float16) - - def test_fwd_bf16(self): - for shape in SHAPES: - self._run_fwd(shape, torch.bfloat16) - def test_fwd_fp32(self): - for shape in SHAPES: - self._run_fwd(shape, torch.float32) - - def _run_bwd(self, shape, dtype): +@requires_cuda +@requires_sm100 +@requires_override +@pytest.mark.parametrize("dtype", DTYPES) +def test_layernorm_bwd(dtype): + atol = 3e-1 if dtype == torch.bfloat16 else _atol_for(dtype) + for shape in SHAPES: normalized_shape = [shape[-1]] x = torch.randn(*shape, dtype=dtype, device="cuda") w = torch.randn(*normalized_shape, dtype=dtype, device="cuda") b = torch.randn(*normalized_shape, dtype=dtype, device="cuda") grad_out = torch.randn(*shape, dtype=dtype, device="cuda") - # Oink x1 = x.detach().requires_grad_(True) w1 = w.detach().requires_grad_(True) b1 = b.detach().requires_grad_(True) @@ -323,164 +351,118 @@ def _run_bwd(self, shape, dtype): ) out1.backward(grad_out) - # Reference x2 = x.detach().requires_grad_(True) w2 = w.detach().requires_grad_(True) b2 = b.detach().requires_grad_(True) out2, _, _ = _orig_native_layer_norm(x2, normalized_shape, w2, b2, EPS) out2.backward(grad_out) - atol = 3e-1 if dtype == torch.bfloat16 else _atol_for(dtype) torch.testing.assert_close( - x1.grad, - x2.grad, - atol=atol, - rtol=0, + x1.grad, x2.grad, atol=atol, rtol=0, msg=f"bwd x_grad shape={shape} dtype={dtype}", ) torch.testing.assert_close( - w1.grad, - w2.grad, - atol=atol, - rtol=0, + w1.grad, w2.grad, atol=atol, rtol=0, msg=f"bwd w_grad shape={shape} dtype={dtype}", ) torch.testing.assert_close( - b1.grad, - b2.grad, - atol=atol, - rtol=0, + b1.grad, b2.grad, atol=atol, rtol=0, msg=f"bwd b_grad shape={shape} dtype={dtype}", ) - def test_bwd_fp16(self): - for shape in SHAPES: - self._run_bwd(shape, torch.float16) - - def test_bwd_bf16(self): - for shape in SHAPES: - self._run_bwd(shape, torch.bfloat16) - - def test_bwd_fp32(self): - for shape in SHAPES: - self._run_bwd(shape, torch.float32) - # ========================================================================= # Correctness tests — Softmax # ========================================================================= -@unittest.skipIf(not TEST_CUDA, "CUDA not available") -@unittest.skipIf(not SM100_OR_LATER, "requires SM100+") -@unittest.skipIf(not _OVERRIDE_APPLIED, "override 
not applied") -class TestSoftmaxOverride(unittest.TestCase): - """Compare oink Softmax override against the captured ATen fallback.""" - - def _run_fwd(self, shape, dtype): +@requires_cuda +@requires_sm100 +@requires_override +@pytest.mark.parametrize("dtype", DTYPES) +def test_softmax_fwd(dtype): + atol = _atol_for(dtype) + for shape in SHAPES: x = torch.randn(*shape, dtype=dtype, device="cuda") - # Oink (dim=-1, half_to_float=False) y = torch.ops.aten._softmax(x, -1, False) - - # Reference y_ref = _orig_softmax(x, -1, False) - atol = _atol_for(dtype) torch.testing.assert_close( y, y_ref, atol=atol, rtol=0, msg=f"fwd shape={shape} dtype={dtype}" ) - def test_fwd_fp16(self): - for shape in SHAPES: - self._run_fwd(shape, torch.float16) - def test_fwd_bf16(self): - for shape in SHAPES: - self._run_fwd(shape, torch.bfloat16) - - def test_fwd_fp32(self): - for shape in SHAPES: - self._run_fwd(shape, torch.float32) - - def _run_bwd(self, shape, dtype): +@requires_cuda +@requires_sm100 +@requires_override +@pytest.mark.parametrize("dtype", DTYPES) +def test_softmax_bwd(dtype): + atol = 3e-1 if dtype == torch.bfloat16 else _atol_for(dtype) + for shape in SHAPES: x = torch.randn(*shape, dtype=dtype, device="cuda") grad_out = torch.randn(*shape, dtype=dtype, device="cuda") - # Oink x1 = x.detach().requires_grad_(True) y1 = torch.softmax(x1, dim=-1) y1.backward(grad_out) - # Reference (manual softmax + bwd to avoid the override) x2 = x.detach().requires_grad_(True) y2 = _orig_softmax(x2, -1, False) dx_ref = _orig_softmax_bwd(grad_out, y2, -1, dtype) - atol = 3e-1 if dtype == torch.bfloat16 else _atol_for(dtype) torch.testing.assert_close( - x1.grad, - dx_ref, - atol=atol, - rtol=0, + x1.grad, dx_ref, atol=atol, rtol=0, msg=f"bwd shape={shape} dtype={dtype}", ) - def test_bwd_fp16(self): - for shape in SHAPES: - self._run_bwd(shape, torch.float16) - - def test_bwd_bf16(self): - for shape in SHAPES: - self._run_bwd(shape, torch.bfloat16) - - def test_bwd_fp32(self): - for shape in SHAPES: - self._run_bwd(shape, torch.float32) - # ========================================================================= # Fallback tests # ========================================================================= -@unittest.skipIf(not TEST_CUDA, "CUDA not available") -@unittest.skipIf(not SM100_OR_LATER, "requires SM100+") -@unittest.skipIf(not _OVERRIDE_APPLIED, "override not applied") -class TestFallback(unittest.TestCase): - """Verify that unsupported inputs fall back to the native CUDA kernel.""" - - def test_float64_rmsnorm_falls_back(self): - """float64 is not supported by oink — should fall back gracefully.""" - x = torch.randn(4, 32, dtype=torch.float64, device="cuda") - w = torch.randn(32, dtype=torch.float64, device="cuda") - y, rstd = torch.ops.aten._fused_rms_norm(x, [32], w, EPS) - self.assertEqual(y.shape, x.shape) - self.assertEqual(y.dtype, torch.float64) - - def test_float64_layernorm_falls_back(self): - x = torch.randn(4, 32, dtype=torch.float64, device="cuda") - w = torch.randn(32, dtype=torch.float64, device="cuda") - b = torch.randn(32, dtype=torch.float64, device="cuda") - out, mean, rstd = torch.ops.aten.native_layer_norm(x, [32], w, b, EPS) - self.assertEqual(out.shape, x.shape) - self.assertEqual(out.dtype, torch.float64) - - def test_float64_softmax_falls_back(self): - x = torch.randn(4, 32, dtype=torch.float64, device="cuda") - y = torch.ops.aten._softmax(x, -1, False) - self.assertEqual(y.shape, x.shape) - self.assertEqual(y.dtype, torch.float64) - - def 
test_non_last_dim_softmax_falls_back(self): - """Softmax on dim=0 should fall back (oink only handles last dim).""" - x = torch.randn(4, 32, dtype=torch.float16, device="cuda") - y = torch.ops.aten._softmax(x, 0, False) - self.assertEqual(y.shape, x.shape) - # Verify correctness: softmax on dim=0 - y_ref = _orig_softmax(x, 0, False) - torch.testing.assert_close(y, y_ref, atol=1e-3, rtol=0) - - -if __name__ == "__main__": - unittest.main() +@requires_cuda +@requires_sm100 +@requires_override +def test_float64_rmsnorm_falls_back(): + """float64 is not supported by oink — should fall back gracefully.""" + x = torch.randn(4, 32, dtype=torch.float64, device="cuda") + w = torch.randn(32, dtype=torch.float64, device="cuda") + y, rstd = torch.ops.aten._fused_rms_norm(x, [32], w, EPS) + assert y.shape == x.shape + assert y.dtype == torch.float64 + + +@requires_cuda +@requires_sm100 +@requires_override +def test_float64_layernorm_falls_back(): + x = torch.randn(4, 32, dtype=torch.float64, device="cuda") + w = torch.randn(32, dtype=torch.float64, device="cuda") + b = torch.randn(32, dtype=torch.float64, device="cuda") + out, mean, rstd = torch.ops.aten.native_layer_norm(x, [32], w, b, EPS) + assert out.shape == x.shape + assert out.dtype == torch.float64 + + +@requires_cuda +@requires_sm100 +@requires_override +def test_float64_softmax_falls_back(): + x = torch.randn(4, 32, dtype=torch.float64, device="cuda") + y = torch.ops.aten._softmax(x, -1, False) + assert y.shape == x.shape + assert y.dtype == torch.float64 + + +@requires_cuda +@requires_sm100 +@requires_override +def test_non_last_dim_softmax_falls_back(): + """Softmax on dim=0 should fall back (oink only handles last dim).""" + x = torch.randn(4, 32, dtype=torch.float16, device="cuda") + y = torch.ops.aten._softmax(x, 0, False) + assert y.shape == x.shape + y_ref = _orig_softmax(x, 0, False) + torch.testing.assert_close(y, y_ref, atol=1e-3, rtol=0) From 2560dc3f0a9c401580aa5f4ea0845274cfa2a44f Mon Sep 17 00:00:00 2001 From: Jiannan Wang Date: Mon, 13 Apr 2026 12:38:35 -0700 Subject: [PATCH 3/6] Refine aten overrides: add env-gated tracing, remove debug prints MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add OINK_TRACE=1 env var toggle for kernel call tracing. When off (default), each call site pays only a single `if _TRACE_ENABLED` boolean check — zero overhead. When on, prints first-invocation markers and tracks call counts per op. - Remove hardcoded debug prints that were used for initial validation. - Fix stray character from linter pass. --- oink/src/kernelagent_oink/__init__.py | 1 - oink/src/kernelagent_oink/aten_override.py | 71 +++++++++++----------- 2 files changed, 36 insertions(+), 36 deletions(-) diff --git a/oink/src/kernelagent_oink/__init__.py b/oink/src/kernelagent_oink/__init__.py index ee8b18f..5449481 100644 --- a/oink/src/kernelagent_oink/__init__.py +++ b/oink/src/kernelagent_oink/__init__.py @@ -157,7 +157,6 @@ def register_all_kernels(*, force: bool = False) -> None: register(force=force) if not _OPS_REGISTERED: - # register() decided to bail (missing deps, no CUDA, env gate, etc.). return # Step 2: override aten ops on CUDA. 
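For context, a minimal sketch of how the new OINK_TRACE toggle is meant to be exercised (assumes an SM100 GPU with the Oink kernels installed; the tensor shapes are illustrative). Since _TRACE_ENABLED is read when the module is imported, the variable must be set beforehand:

```python
import os

os.environ["OINK_TRACE"] = "1"  # must be set before kernelagent_oink is imported

import torch
import torch.nn.functional as F

import kernelagent_oink

kernelagent_oink.register_all_kernels(force=True)

x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
w = torch.ones(1024, device="cuda", dtype=torch.bfloat16)

# First call prints "[OINK] _fused_rms_norm override called (first invocation)";
# later calls only bump the per-op counter.
y = F.rms_norm(x, [1024], w, 1e-6)
```

Passing force=True bypasses the VLLM_USE_OINK_RMSNORM environment gate, which is convenient when validating the tracing path outside vLLM.
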
diff --git a/oink/src/kernelagent_oink/aten_override.py b/oink/src/kernelagent_oink/aten_override.py index c2b3ecd..5573316 100644 --- a/oink/src/kernelagent_oink/aten_override.py +++ b/oink/src/kernelagent_oink/aten_override.py @@ -48,6 +48,7 @@ import importlib import logging import math +import os import threading from typing import List, Optional, Tuple @@ -55,9 +56,6 @@ logger = logging.getLogger(__name__) -# --------------------------------------------------------------------------- -# Lazy kernel module imports -# --------------------------------------------------------------------------- _MOD_CACHE: dict[str, object] = {} _MOD_LOCK = threading.Lock() @@ -76,10 +74,6 @@ def _get_mod(name: str): return _MOD_CACHE[name] -# --------------------------------------------------------------------------- -# Device capability helpers -# --------------------------------------------------------------------------- - @functools.cache def _get_device_sm(device: torch.device) -> int: @@ -89,6 +83,27 @@ def _get_device_sm(device: torch.device) -> int: _SUPPORTED_DTYPES = (torch.float16, torch.bfloat16, torch.float32) +# --------------------------------------------------------------------------- +# Optional tracing — zero overhead when disabled. +# Enable with OINK_TRACE=1 environment variable. +# --------------------------------------------------------------------------- + +_TRACE_ENABLED: bool = os.environ.get("OINK_TRACE", "0").strip() in ("1", "true") +_call_counts: dict[str, int] = {} + + +def _trace_call(op_name: str) -> None: + """Record and print a marker on the first invocation of each override. + + Only active when ``OINK_TRACE=1`` is set in the environment. When + disabled, each call site pays only a single boolean check + (``if _TRACE_ENABLED``) — effectively zero overhead. + """ + n = _call_counts.get(op_name, 0) + 1 + _call_counts[op_name] = n + if n == 1: + print(f"[OINK] {op_name} override called (first invocation)", flush=True) + def _is_supported(t: torch.Tensor) -> bool: """True when Oink's SM100 kernel can handle this tensor.""" @@ -98,11 +113,6 @@ def _is_supported(t: torch.Tensor) -> bool: and _get_device_sm(t.device) >= 100 ) - -# --------------------------------------------------------------------------- -# Fallback kernel capture -# --------------------------------------------------------------------------- - _fallbacks: dict[str, object] = {} @@ -128,11 +138,6 @@ def _call_fallback(op_name: str, *args): ) -# --------------------------------------------------------------------------- -# Reshape / stride helpers -# --------------------------------------------------------------------------- - - def _can_view_as_2d(x: torch.Tensor) -> bool: """Return True if ``x.view(-1, x.shape[-1])`` is viewable (no copy). 
@@ -162,9 +167,9 @@ def _is_oink_stride_compatible_2d(x_2d: torch.Tensor) -> bool: if x_2d.stride(1) != 1: return False if x_2d.dtype in (torch.float16, torch.bfloat16): - divby = 16 # 256 bits / 16 bits = 16 elements + divby = 16 elif x_2d.dtype == torch.float32: - divby = 8 # 256 bits / 32 bits = 8 elements + divby = 8 else: return False return (x_2d.stride(0) % divby) == 0 @@ -206,9 +211,6 @@ def _stat_shape(input_shape, normalized_shape_len: int) -> list[int]: return list(input_shape[:-normalized_shape_len]) + [1] * normalized_shape_len -# ========================================================================= -# RMSNorm -# ========================================================================= def _oink_fused_rms_norm( @@ -235,6 +237,8 @@ def _oink_fused_rms_norm( "_fused_rms_norm", input, normalized_shape, weight, eps ) + if _TRACE_ENABLED: + _trace_call("_fused_rms_norm") mod = _get_mod("rmsnorm") y, rstd, _ = mod.rmsnorm_forward( x, weight=weight, bias=None, residual=None, eps=eps, store_rstd=True, @@ -266,6 +270,8 @@ def _oink_fused_rms_norm_backward( dout = _reshape_2d(grad_out, M, N) rstd_flat = _flatten_1d(rstd, M) + if _TRACE_ENABLED: + _trace_call("_fused_rms_norm_backward") mod = _get_mod("rmsnorm") dx, dw, _dbias, _dres = mod.rmsnorm_backward( x, weight, dout, rstd_flat, @@ -282,11 +288,6 @@ def _oink_fused_rms_norm_backward( return dx, dw -# ========================================================================= -# LayerNorm -# ========================================================================= - - def _oink_native_layer_norm( input: torch.Tensor, normalized_shape: List[int], @@ -309,6 +310,8 @@ def _oink_native_layer_norm( "native_layer_norm", input, normalized_shape, weight, bias, eps ) + if _TRACE_ENABLED: + _trace_call("native_layer_norm") mod = _get_mod("layernorm") out, rstd, mean = mod.layernorm( x, weight, bias=bias, eps=eps, return_rstd=True, return_mean=True, @@ -346,6 +349,8 @@ def _oink_native_layer_norm_backward( mean_flat = _flatten_1d(mean, M) rstd_flat = _flatten_1d(rstd, M) + if _TRACE_ENABLED: + _trace_call("native_layer_norm_backward") mod = _get_mod("layernorm") dx, dw, db = mod.layernorm_backward( dout, x, weight, rstd_flat, mean_flat, bias=bias, @@ -363,11 +368,6 @@ def _oink_native_layer_norm_backward( return dx, dw, db -# ========================================================================= -# Softmax -# ========================================================================= - - def _oink_softmax( self: torch.Tensor, dim: int, @@ -389,6 +389,8 @@ def _oink_softmax( if x is None: return _call_fallback("_softmax", self, dim, half_to_float) + if _TRACE_ENABLED: + _trace_call("_softmax") mod = _get_mod("softmax") y = mod.softmax_forward(x) @@ -420,15 +422,14 @@ def _oink_softmax_backward( dy = _reshape_2d(grad_output, M, N) y = _reshape_2d(output, M, N) + if _TRACE_ENABLED: + _trace_call("_softmax_backward_data") mod = _get_mod("softmax") dx = mod.softmax_backward(dy, y) return dx.reshape(input_shape) -# ========================================================================= -# Registration -# ========================================================================= _ATEN_LIB: torch.library.Library | None = None From 6489635212183caf431c741402e3cc7392cc8978 Mon Sep 17 00:00:00 2001 From: Jiannan Wang Date: Mon, 13 Apr 2026 12:38:35 -0700 Subject: [PATCH 4/6] Refine aten overrides: add env-gated tracing, remove debug prints MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add OINK_TRACE=1 env 
var toggle for kernel call tracing. When off (default), each call site pays only a single `if _TRACE_ENABLED` boolean check — zero overhead. When on, prints first-invocation markers and tracks call counts per op. - Remove hardcoded debug prints that were used for initial validation. - Fix stray character from linter pass. --- oink/src/kernelagent_oink/__init__.py | 1 - oink/src/kernelagent_oink/aten_override.py | 71 +++++++++++----------- oink/tests/test_aten_override.py | 1 - 3 files changed, 36 insertions(+), 37 deletions(-) diff --git a/oink/src/kernelagent_oink/__init__.py b/oink/src/kernelagent_oink/__init__.py index ee8b18f..5449481 100644 --- a/oink/src/kernelagent_oink/__init__.py +++ b/oink/src/kernelagent_oink/__init__.py @@ -157,7 +157,6 @@ def register_all_kernels(*, force: bool = False) -> None: register(force=force) if not _OPS_REGISTERED: - # register() decided to bail (missing deps, no CUDA, env gate, etc.). return # Step 2: override aten ops on CUDA. diff --git a/oink/src/kernelagent_oink/aten_override.py b/oink/src/kernelagent_oink/aten_override.py index c2b3ecd..5573316 100644 --- a/oink/src/kernelagent_oink/aten_override.py +++ b/oink/src/kernelagent_oink/aten_override.py @@ -48,6 +48,7 @@ import importlib import logging import math +import os import threading from typing import List, Optional, Tuple @@ -55,9 +56,6 @@ logger = logging.getLogger(__name__) -# --------------------------------------------------------------------------- -# Lazy kernel module imports -# --------------------------------------------------------------------------- _MOD_CACHE: dict[str, object] = {} _MOD_LOCK = threading.Lock() @@ -76,10 +74,6 @@ def _get_mod(name: str): return _MOD_CACHE[name] -# --------------------------------------------------------------------------- -# Device capability helpers -# --------------------------------------------------------------------------- - @functools.cache def _get_device_sm(device: torch.device) -> int: @@ -89,6 +83,27 @@ def _get_device_sm(device: torch.device) -> int: _SUPPORTED_DTYPES = (torch.float16, torch.bfloat16, torch.float32) +# --------------------------------------------------------------------------- +# Optional tracing — zero overhead when disabled. +# Enable with OINK_TRACE=1 environment variable. +# --------------------------------------------------------------------------- + +_TRACE_ENABLED: bool = os.environ.get("OINK_TRACE", "0").strip() in ("1", "true") +_call_counts: dict[str, int] = {} + + +def _trace_call(op_name: str) -> None: + """Record and print a marker on the first invocation of each override. + + Only active when ``OINK_TRACE=1`` is set in the environment. When + disabled, each call site pays only a single boolean check + (``if _TRACE_ENABLED``) — effectively zero overhead. 
+ """ + n = _call_counts.get(op_name, 0) + 1 + _call_counts[op_name] = n + if n == 1: + print(f"[OINK] {op_name} override called (first invocation)", flush=True) + def _is_supported(t: torch.Tensor) -> bool: """True when Oink's SM100 kernel can handle this tensor.""" @@ -98,11 +113,6 @@ def _is_supported(t: torch.Tensor) -> bool: and _get_device_sm(t.device) >= 100 ) - -# --------------------------------------------------------------------------- -# Fallback kernel capture -# --------------------------------------------------------------------------- - _fallbacks: dict[str, object] = {} @@ -128,11 +138,6 @@ def _call_fallback(op_name: str, *args): ) -# --------------------------------------------------------------------------- -# Reshape / stride helpers -# --------------------------------------------------------------------------- - - def _can_view_as_2d(x: torch.Tensor) -> bool: """Return True if ``x.view(-1, x.shape[-1])`` is viewable (no copy). @@ -162,9 +167,9 @@ def _is_oink_stride_compatible_2d(x_2d: torch.Tensor) -> bool: if x_2d.stride(1) != 1: return False if x_2d.dtype in (torch.float16, torch.bfloat16): - divby = 16 # 256 bits / 16 bits = 16 elements + divby = 16 elif x_2d.dtype == torch.float32: - divby = 8 # 256 bits / 32 bits = 8 elements + divby = 8 else: return False return (x_2d.stride(0) % divby) == 0 @@ -206,9 +211,6 @@ def _stat_shape(input_shape, normalized_shape_len: int) -> list[int]: return list(input_shape[:-normalized_shape_len]) + [1] * normalized_shape_len -# ========================================================================= -# RMSNorm -# ========================================================================= def _oink_fused_rms_norm( @@ -235,6 +237,8 @@ def _oink_fused_rms_norm( "_fused_rms_norm", input, normalized_shape, weight, eps ) + if _TRACE_ENABLED: + _trace_call("_fused_rms_norm") mod = _get_mod("rmsnorm") y, rstd, _ = mod.rmsnorm_forward( x, weight=weight, bias=None, residual=None, eps=eps, store_rstd=True, @@ -266,6 +270,8 @@ def _oink_fused_rms_norm_backward( dout = _reshape_2d(grad_out, M, N) rstd_flat = _flatten_1d(rstd, M) + if _TRACE_ENABLED: + _trace_call("_fused_rms_norm_backward") mod = _get_mod("rmsnorm") dx, dw, _dbias, _dres = mod.rmsnorm_backward( x, weight, dout, rstd_flat, @@ -282,11 +288,6 @@ def _oink_fused_rms_norm_backward( return dx, dw -# ========================================================================= -# LayerNorm -# ========================================================================= - - def _oink_native_layer_norm( input: torch.Tensor, normalized_shape: List[int], @@ -309,6 +310,8 @@ def _oink_native_layer_norm( "native_layer_norm", input, normalized_shape, weight, bias, eps ) + if _TRACE_ENABLED: + _trace_call("native_layer_norm") mod = _get_mod("layernorm") out, rstd, mean = mod.layernorm( x, weight, bias=bias, eps=eps, return_rstd=True, return_mean=True, @@ -346,6 +349,8 @@ def _oink_native_layer_norm_backward( mean_flat = _flatten_1d(mean, M) rstd_flat = _flatten_1d(rstd, M) + if _TRACE_ENABLED: + _trace_call("native_layer_norm_backward") mod = _get_mod("layernorm") dx, dw, db = mod.layernorm_backward( dout, x, weight, rstd_flat, mean_flat, bias=bias, @@ -363,11 +368,6 @@ def _oink_native_layer_norm_backward( return dx, dw, db -# ========================================================================= -# Softmax -# ========================================================================= - - def _oink_softmax( self: torch.Tensor, dim: int, @@ -389,6 +389,8 @@ def _oink_softmax( if x is None: return 
_call_fallback("_softmax", self, dim, half_to_float) + if _TRACE_ENABLED: + _trace_call("_softmax") mod = _get_mod("softmax") y = mod.softmax_forward(x) @@ -420,15 +422,14 @@ def _oink_softmax_backward( dy = _reshape_2d(grad_output, M, N) y = _reshape_2d(output, M, N) + if _TRACE_ENABLED: + _trace_call("_softmax_backward_data") mod = _get_mod("softmax") dx = mod.softmax_backward(dy, y) return dx.reshape(input_shape) -# ========================================================================= -# Registration -# ========================================================================= _ATEN_LIB: torch.library.Library | None = None diff --git a/oink/tests/test_aten_override.py b/oink/tests/test_aten_override.py index 3efa93c..6ecb635 100644 --- a/oink/tests/test_aten_override.py +++ b/oink/tests/test_aten_override.py @@ -24,7 +24,6 @@ from __future__ import annotations -import math import types import pytest From 22e4366f5bc43728eb599e7e66aa9d8bf2fd4784 Mon Sep 17 00:00:00 2001 From: Jiannan Wang Date: Tue, 14 Apr 2026 12:10:07 -0700 Subject: [PATCH 5/6] Add TVM-FFI compiled _compile_rmsnorm_fwd API and simplify aten override MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add _compile_rmsnorm_fwd() to rmsnorm.py: pre-compiles the RMSNorm kernel with --enable-tvm-ffi so the returned callable accepts torch tensors directly via DLPack at C++ level. Eliminates per-call Python pointer construction overhead (rt.make_ptr, Int32, Float32). - Simplify aten_override.py to rmsnorm-only (fwd + bwd), following the quack PR pattern: with_keyset=True, fallback via call_boxed, functools.partial for dispatch_keys binding. - Add benchmark_rmsnorm_oink_vs_aten.py for kernel-level comparison using do_bench_triton (same infrastructure as benchmark_rmsnorm_sm100). Results on GB200 (bf16, store_rstd=True): (4096, 4096): oink 0.018ms vs aten 0.023ms → 1.29x (4096, 8192): oink 0.028ms vs aten 0.038ms → 1.32x (16384, 8192): oink 0.086ms vs aten 0.103ms → 1.19x (65536, 8192): oink 0.335ms vs aten 0.375ms → 1.12x --- .../benchmark_rmsnorm_oink_vs_aten.py | 277 +++++++++++ oink/src/kernelagent_oink/aten_override.py | 445 ++++-------------- .../src/kernelagent_oink/blackwell/rmsnorm.py | 108 +++++ 3 files changed, 472 insertions(+), 358 deletions(-) create mode 100644 oink/benchmarks/benchmark/benchmark_rmsnorm_oink_vs_aten.py diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_oink_vs_aten.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_oink_vs_aten.py new file mode 100644 index 0000000..f29375e --- /dev/null +++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_oink_vs_aten.py @@ -0,0 +1,277 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Benchmark Oink CuTeDSL RMSNorm vs PyTorch Aten RMSNorm (aten::_fused_rms_norm). + +Based on benchmark_rmsnorm_sm100.py — replaces quack with torch's native aten +kernel so that oink is compared directly against PyTorch's built-in CUDA +implementation. 
+ +Both kernels are called at the same level: direct function call, no aten +override dispatch layer. This isolates kernel performance from Python/dispatch +overhead. +""" + +from __future__ import annotations + +import argparse +import os +from typing import Any, Dict, List, Tuple + +import torch + +os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True") + +from bench_utils import ( # noqa: E402 + ErrorStatsAccumulator, + collect_device_meta, + detect_hbm_peak_gbps, + do_bench_triton, + ensure_blackwell_arch_env, + error_stats_to_row, + ensure_oink_src_on_path, + iter_row_blocks, + parse_configs, + parse_dtype, + quack_suite_configs, + write_csv, + write_json, +) + +ensure_blackwell_arch_env() +ensure_oink_src_on_path() + +from kernelagent_oink.blackwell import rmsnorm as oink_rmsnorm # noqa: E402 + +# PyTorch aten _fused_rms_norm — called directly to avoid any override layer. +_aten_fused_rms_norm = torch.ops.aten._fused_rms_norm + +_VERIFY_TOL_Y = { + torch.float32: dict(atol=1e-4, rtol=1e-3), + torch.float16: dict(atol=1e-2, rtol=1e-3), + torch.bfloat16: dict(atol=1e-1, rtol=1e-2), +} + +_VERIFY_TOL_RSTD = { + torch.float32: dict(atol=1e-5, rtol=1e-5), + torch.float16: dict(atol=1e-3, rtol=1e-3), + torch.bfloat16: dict(atol=1e-3, rtol=1e-3), +} + + +def bytes_io_model_fwd( + M: int, N: int, dtype: torch.dtype, *, weight_dtype: torch.dtype = torch.float32 +) -> int: + elem = torch.tensor(0, dtype=dtype).element_size() + w_elem = torch.tensor(0, dtype=weight_dtype).element_size() + total = 2 * M * N * elem # read x + write y + total += N * w_elem # read weight + return int(total) + + +def dsv3_configs() -> List[Tuple[int, int]]: + Ms = [4096, 16384, 65536] + Ns = [6144, 7168, 8192] + return [(m, n) for m in Ms for n in Ns] + + +def _verify_parity( + x: torch.Tensor, + w: torch.Tensor, + *, + eps: float, + store_rstd: bool, +) -> dict[str, object]: + tol_y = _VERIFY_TOL_Y[x.dtype] + tol_rstd = _VERIFY_TOL_RSTD[x.dtype] + ref_block_rows = 4096 + M = int(x.shape[0]) + N = int(x.shape[1]) + + y_acc_ours = ErrorStatsAccumulator(total_elems=M * N) + y_acc_aten = ErrorStatsAccumulator(total_elems=M * N) + + with torch.no_grad(): + y_o, rstd_o, res_o = oink_rmsnorm.rmsnorm_forward( + x, weight=w, bias=None, residual=None, eps=eps, store_rstd=store_rstd, + ) + y_a, rstd_a = _aten_fused_rms_norm(x, [N], w, eps) + + # Pure-PyTorch reference (float32 accumulation), chunked. 
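+        # (rstd_ref = 1/sqrt(mean(x**2, dim=-1) + eps) in fp32; the reference
+        # output is (x * rstd_ref) * w, computed ref_block_rows rows at a
+        # time to bound peak memory at large M.)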
+ w_f32 = w.float() + rstd_ref = torch.empty((M,), device=x.device, dtype=torch.float32) + for start, end in iter_row_blocks(M, ref_block_rows): + x_f32 = x[start:end].float() + rstd_blk = torch.rsqrt(x_f32.square().mean(dim=-1) + eps) + rstd_ref[start:end] = rstd_blk + + y_ref_blk_f32 = (x_f32 * rstd_blk.unsqueeze(1)) * w_f32 + y_ref_blk = y_ref_blk_f32.to(x.dtype) + torch.testing.assert_close(y_o[start:end], y_ref_blk, **tol_y) + y_acc_ours.update(y_o[start:end], y_ref_blk) + torch.testing.assert_close(y_a[start:end], y_ref_blk, **tol_y) + y_acc_aten.update(y_a[start:end], y_ref_blk) + + stats: dict[str, object] = {} + stats.update(error_stats_to_row("ours_err_y", y_acc_ours.finalize())) + stats.update(error_stats_to_row("aten_err_y", y_acc_aten.finalize())) + + if store_rstd: + assert rstd_o is not None + torch.testing.assert_close(rstd_o, rstd_ref, **tol_rstd) + rstd_acc_ours = ErrorStatsAccumulator( + total_elems=int(rstd_ref.numel()), + p99_target_samples=int(rstd_ref.numel()), + ) + rstd_acc_ours.update(rstd_o, rstd_ref) + stats.update(error_stats_to_row("ours_err_rstd", rstd_acc_ours.finalize())) + + assert res_o is None + return stats + + +def bench_single( + M: int, + N: int, + dtype: torch.dtype, + *, + weight_dtype: torch.dtype, + eps: float, + warmup_ms: int, + iters_ms: int, + verify: bool, + store_rstd: bool, +) -> Tuple[Tuple[float, float], Tuple[float, float], dict[str, object]]: + device = torch.device("cuda") + x = torch.randn(M, N, device=device, dtype=dtype) + w = torch.randn(N, device=device, dtype=weight_dtype) + + stats: dict[str, object] = {} + if verify: + stats = _verify_parity(x, w, eps=eps, store_rstd=store_rstd) + + bytes_io = bytes_io_model_fwd(M, N, dtype, weight_dtype=w.dtype) + + # Oink: call rmsnorm_forward directly (same as benchmark_rmsnorm_sm100.py). + def fn_oink(): + return oink_rmsnorm.rmsnorm_forward( + x, weight=w, bias=None, residual=None, eps=eps, store_rstd=store_rstd, + ) + + ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms) + gbps_oink = bytes_io / (ms_oink * 1e-3) / 1e9 + + # Aten: call _fused_rms_norm directly (no Python override layer). 
+ def fn_aten(): + return _aten_fused_rms_norm(x, [N], w, eps) + + ms_aten = do_bench_triton(fn_aten, warmup_ms=warmup_ms, rep_ms=iters_ms) + gbps_aten = bytes_io / (ms_aten * 1e-3) / 1e9 + + return (ms_oink, gbps_oink), (ms_aten, gbps_aten), stats + + +def main() -> None: + if not torch.cuda.is_available(): + raise SystemExit("CUDA not available") + + torch.cuda.set_device(0) + device = torch.device("cuda") + props = torch.cuda.get_device_properties(device) + sm = props.major * 10 + props.minor + print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})") + + p = argparse.ArgumentParser() + p.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"]) + p.add_argument( + "--weight-dtype", type=str, default="fp32", + choices=["same", "fp16", "bf16", "fp32"], + ) + p.add_argument("--eps", type=float, default=1e-6) + p.add_argument("--store-rstd", action="store_true") + p.add_argument("--iters", type=int, default=100) + p.add_argument("--warmup-ms", type=int, default=25) + p.add_argument("--csv", type=str, default=None) + p.add_argument("--json", type=str, default=None) + p.add_argument("--configs", type=str, default="1024x4096,8192x4096") + p.add_argument("--quack-suite", action="store_true") + p.add_argument("--dsv3", action="store_true") + p.add_argument("--skip-verify", action="store_true") + args = p.parse_args() + + dtype = parse_dtype(args.dtype) + weight_dtype = dtype if args.weight_dtype == "same" else parse_dtype(args.weight_dtype) + eps = float(args.eps) + + if args.quack_suite: + cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()] + elif args.dsv3: + cfgs = dsv3_configs() + else: + cfgs = parse_configs(args.configs) + + hbm_peak = detect_hbm_peak_gbps(device) + meta = collect_device_meta(device) + + rows_out: List[Dict[str, Any]] = [] + for M, N in cfgs: + print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} ...", flush=True) + (ms_oink, gbps_oink), (ms_aten, gbps_aten), stats = bench_single( + M=M, N=N, dtype=dtype, weight_dtype=weight_dtype, eps=eps, + warmup_ms=int(args.warmup_ms), iters_ms=int(args.iters), + verify=not args.skip_verify, store_rstd=bool(args.store_rstd), + ) + row: Dict[str, Any] = { + "M": M, "N": N, "dtype": args.dtype, + "weight_dtype": args.weight_dtype, "eps": eps, + "store_rstd": bool(args.store_rstd), + "oink_ms": ms_oink, "oink_gbps": gbps_oink, + "oink_tbps": gbps_oink / 1000.0, + "oink_hbm_frac": gbps_oink / hbm_peak, + "aten_ms": ms_aten, "aten_gbps": gbps_aten, + "aten_tbps": gbps_aten / 1000.0, + "speedup_vs_aten": ms_aten / ms_oink, + } + row.update(stats) + rows_out.append(row) + + if args.csv is not None: + write_csv(args.csv, rows_out) + if args.json is not None: + write_json(args.json, meta, rows_out, extra={ + "method": "triton.testing.do_bench(mean)", + "warmup_ms": int(args.warmup_ms), "rep_ms": int(args.iters), + "io_model_bytes": "(2*M*N)*elem_size + N*weight_elem_size", + "store_rstd": bool(args.store_rstd), + "weight_dtype": str(args.weight_dtype), + }) + + # Compact summary table. 
+ headers = ["M", "N", "oink_ms", "oink_tbps", "aten_ms", "aten_tbps", "speedup_vs_aten"] + print("\nSummary:") + print(" ".join(h.rjust(16) for h in headers)) + for r in rows_out: + parts: List[str] = [] + for h in headers: + v = r.get(h) + if isinstance(v, float): + parts.append(f"{v:16.4f}") + else: + parts.append(f"{str(v):>16}") + print(" ".join(parts)) + + +if __name__ == "__main__": + main() diff --git a/oink/src/kernelagent_oink/aten_override.py b/oink/src/kernelagent_oink/aten_override.py index 5573316..cfcdc95 100644 --- a/oink/src/kernelagent_oink/aten_override.py +++ b/oink/src/kernelagent_oink/aten_override.py @@ -13,79 +13,33 @@ # limitations under the License. """ -Override PyTorch aten operators with Oink's Blackwell CuTeDSL kernels. +Override ``aten::_fused_rms_norm`` with Oink's Blackwell CuTeDSL RMSNorm. -Patches the following aten ops at the CUDA dispatch key: - -- ``aten::_fused_rms_norm`` → :func:`rmsnorm_forward` -- ``aten::_fused_rms_norm_backward`` → :func:`rmsnorm_backward` -- ``aten::native_layer_norm`` → :func:`layernorm` -- ``aten::native_layer_norm_backward`` → :func:`layernorm_backward` -- ``aten::_softmax`` → :func:`softmax_forward` -- ``aten::_softmax_backward_data`` → :func:`softmax_backward` - -All standard PyTorch APIs (``F.rms_norm``, ``F.layer_norm``, ``F.softmax``, -etc.) transparently route through the Oink kernels on SM100+ CUDA devices -after calling :func:`override_all_aten_kernels`. - -Each override captures the original CUDA kernel via -``torch.library.get_kernel`` *before* patching, so unsupported inputs -(wrong dtype, older GPU) fall back to PyTorch's native implementation. +Follows the quack PR pattern (``torch/_native/ops/norm/rmsnorm_impl.py``): +registers with ``with_keyset=True`` so the impl receives a +``DispatchKeySet`` and can call the captured fallback via ``call_boxed``. Usage:: - from kernelagent_oink.aten_override import override_all_aten_kernels - - override_all_aten_kernels() - y = torch.nn.functional.rms_norm(x, [N], weight, eps) # uses Oink - - restore_all_aten_kernels() # restores PyTorch + import kernelagent_oink + kernelagent_oink.register_all_kernels(force=True) + y = torch.nn.functional.rms_norm(x, [N], weight, eps) # uses Oink """ from __future__ import annotations import functools -import importlib import logging import math import os -import threading from typing import List, Optional, Tuple import torch logger = logging.getLogger(__name__) - -_MOD_CACHE: dict[str, object] = {} -_MOD_LOCK = threading.Lock() - - -def _get_mod(name: str): - """Thread-safe lazy import of ``kernelagent_oink.blackwell.``.""" - cached = _MOD_CACHE.get(name) - if cached is not None: - return cached - with _MOD_LOCK: - if name not in _MOD_CACHE: - _MOD_CACHE[name] = importlib.import_module( - f"kernelagent_oink.blackwell.{name}" - ) - return _MOD_CACHE[name] - - - -@functools.cache -def _get_device_sm(device: torch.device) -> int: - major, minor = torch.cuda.get_device_capability(device) - return 10 * major + minor - - -_SUPPORTED_DTYPES = (torch.float16, torch.bfloat16, torch.float32) - # --------------------------------------------------------------------------- -# Optional tracing — zero overhead when disabled. -# Enable with OINK_TRACE=1 environment variable. +# Optional tracing — zero overhead when disabled (OINK_TRACE=1). 
# --------------------------------------------------------------------------- _TRACE_ENABLED: bool = os.environ.get("OINK_TRACE", "0").strip() in ("1", "true") @@ -93,397 +47,172 @@ def _get_device_sm(device: torch.device) -> int: def _trace_call(op_name: str) -> None: - """Record and print a marker on the first invocation of each override. - - Only active when ``OINK_TRACE=1`` is set in the environment. When - disabled, each call site pays only a single boolean check - (``if _TRACE_ENABLED``) — effectively zero overhead. - """ n = _call_counts.get(op_name, 0) + 1 _call_counts[op_name] = n if n == 1: print(f"[OINK] {op_name} override called (first invocation)", flush=True) -def _is_supported(t: torch.Tensor) -> bool: - """True when Oink's SM100 kernel can handle this tensor.""" - return ( - t.is_cuda - and t.dtype in _SUPPORTED_DTYPES - and _get_device_sm(t.device) >= 100 - ) +# --------------------------------------------------------------------------- +# Device / dtype support check +# --------------------------------------------------------------------------- -_fallbacks: dict[str, object] = {} +@functools.cache +def _get_device_major(device: torch.device) -> int: + major, _ = torch.cuda.get_device_capability(device) + return major -def _capture_fallback(op_name: str, dispatch_key: str = "CUDA") -> None: - """Snapshot the current CUDA kernel for ``aten::`` before we - overwrite it. Must be called *before* ``lib.impl``.""" - if op_name in _fallbacks: - return - try: - _fallbacks[op_name] = torch.library.get_kernel( - f"aten::{op_name}", dispatch_key - ) - except Exception: - _fallbacks[op_name] = None +def _is_supported(input: torch.Tensor) -> bool: + return input.dtype in ( + torch.float16, + torch.bfloat16, + torch.float32, + ) and _get_device_major(input.device) in (9, 10) -def _call_fallback(op_name: str, *args): - fb = _fallbacks.get(op_name) - if fb is not None: - return fb(*args) - raise RuntimeError( - f"Oink: no fallback captured for aten::{op_name} and input is unsupported" - ) +# ========================================================================= +# RMSNorm forward +# ========================================================================= -def _can_view_as_2d(x: torch.Tensor) -> bool: - """Return True if ``x.view(-1, x.shape[-1])`` is viewable (no copy). - For a view(-1, N) to be valid, all leading dims must be contiguous with - respect to each other (size-1 dims are ignored). - """ - if x.dim() < 2: - return False - if x.dim() == 2: - return True - for dim in range(x.dim() - 1): - if x.size(dim + 1) != 1 and x.stride(dim) != x.stride(dim + 1) * x.size( - dim + 1 - ): - return False - return True - - -def _is_oink_stride_compatible_2d(x_2d: torch.Tensor) -> bool: - """Return True if *x_2d* meets Oink's pointer-path stride constraints. - - Requires stride(1) == 1 (row-major last dim) and stride(0) divisible by - the vectorization granularity (256 bits). - """ - if x_2d.dim() != 2: - return False - if x_2d.stride(1) != 1: - return False - if x_2d.dtype in (torch.float16, torch.bfloat16): - divby = 16 - elif x_2d.dtype == torch.float32: - divby = 8 - else: - return False - return (x_2d.stride(0) % divby) == 0 - - -def _reshape_2d(t: torch.Tensor, M: int, N: int) -> torch.Tensor: - if t.ndim == 2 and t.shape == (M, N) and t.is_contiguous(): - return t - return t.reshape(M, N).contiguous() - - -def _reshape_2d_checked(t: torch.Tensor, M: int, N: int) -> torch.Tensor | None: - """Reshape to 2D and return None if the result doesn't meet Oink's stride - constraints. 
Callers should fall back to the native kernel on None.""" - if not _can_view_as_2d(t): - return None - x_2d = t.view(-1, N) if t.dim() > 2 else t - if x_2d.shape != (M, N): - x_2d = t.reshape(M, N) - if not x_2d.is_contiguous() and not _is_oink_stride_compatible_2d(x_2d): - # contiguous() always produces stride-compatible layout. - x_2d = x_2d.contiguous() - if not _is_oink_stride_compatible_2d(x_2d): - return None - return x_2d - - -def _flatten_1d(t: torch.Tensor, M: int) -> torch.Tensor: - if t.ndim == 1 and t.shape[0] == M: - return t - if t.is_contiguous() and t.numel() == M: - return t.detach().view(M) - return t.reshape(M).contiguous() - - -def _stat_shape(input_shape, normalized_shape_len: int) -> list[int]: - """Shape for rstd / mean: ``[*batch_dims, 1, 1, ...]`` with - ``normalized_shape_len`` trailing ones.""" - return list(input_shape[:-normalized_shape_len]) + [1] * normalized_shape_len - - - - -def _oink_fused_rms_norm( +def _fused_rms_norm_impl( + dispatch_keys: torch.DispatchKeySet, input: torch.Tensor, normalized_shape: List[int], weight: Optional[torch.Tensor], eps: Optional[float], + *, + fallback_kernel, ) -> Tuple[torch.Tensor, torch.Tensor]: if not _is_supported(input): - return _call_fallback( - "_fused_rms_norm", input, normalized_shape, weight, eps + return fallback_kernel.call_boxed( + dispatch_keys, input, normalized_shape, weight, eps ) if eps is None: eps = 1e-6 - input_shape = input.shape + if _TRACE_ENABLED: + _trace_call("_fused_rms_norm") + + orig_shape = input.shape N = math.prod(normalized_shape) M = input.numel() // N - x = _reshape_2d_checked(input, M, N) - if x is None: - return _call_fallback( - "_fused_rms_norm", input, normalized_shape, weight, eps - ) + x = input.reshape(M, N) - if _TRACE_ENABLED: - _trace_call("_fused_rms_norm") - mod = _get_mod("rmsnorm") - y, rstd, _ = mod.rmsnorm_forward( + from kernelagent_oink.blackwell.rmsnorm import rmsnorm_forward + + y, rstd, _ = rmsnorm_forward( x, weight=weight, bias=None, residual=None, eps=eps, store_rstd=True, ) - y = y.reshape(input_shape) - rstd = rstd.view(_stat_shape(input_shape, len(normalized_shape))) + y = y.reshape(orig_shape) + stat_shape = list(orig_shape[: -len(normalized_shape)]) + [1] * len( + normalized_shape + ) + rstd = rstd.view(stat_shape) return y, rstd -def _oink_fused_rms_norm_backward( +# ========================================================================= +# RMSNorm backward +# ========================================================================= + + +def _fused_rms_norm_backward_impl( + dispatch_keys: torch.DispatchKeySet, grad_out: torch.Tensor, input: torch.Tensor, normalized_shape: List[int], rstd: torch.Tensor, weight: Optional[torch.Tensor], output_mask: List[bool], + *, + fallback_kernel, ) -> Tuple[torch.Tensor, torch.Tensor]: if not _is_supported(input): - return _call_fallback( - "_fused_rms_norm_backward", + return fallback_kernel.call_boxed( + dispatch_keys, grad_out, input, normalized_shape, rstd, weight, output_mask, ) - N = math.prod(normalized_shape) - M = input.numel() // N - - x = _reshape_2d(input, M, N) - dout = _reshape_2d(grad_out, M, N) - rstd_flat = _flatten_1d(rstd, M) - if _TRACE_ENABLED: _trace_call("_fused_rms_norm_backward") - mod = _get_mod("rmsnorm") - dx, dw, _dbias, _dres = mod.rmsnorm_backward( - x, weight, dout, rstd_flat, - dresidual_out=None, has_bias=False, has_residual=False, - ) - - dx = dx.reshape(input.shape) - - if not output_mask[0]: - dx = torch.zeros_like(input) - if not output_mask[1] or dw is None: - dw = 
torch.zeros_like(weight) if weight is not None else torch.empty(0) - - return dx, dw - -def _oink_native_layer_norm( - input: torch.Tensor, - normalized_shape: List[int], - weight: Optional[torch.Tensor], - bias: Optional[torch.Tensor], - eps: float, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if not _is_supported(input): - return _call_fallback( - "native_layer_norm", input, normalized_shape, weight, bias, eps - ) - - input_shape = input.shape N = math.prod(normalized_shape) M = input.numel() // N - x = _reshape_2d_checked(input, M, N) - if x is None: - return _call_fallback( - "native_layer_norm", input, normalized_shape, weight, bias, eps - ) - - if _TRACE_ENABLED: - _trace_call("native_layer_norm") - mod = _get_mod("layernorm") - out, rstd, mean = mod.layernorm( - x, weight, bias=bias, eps=eps, return_rstd=True, return_mean=True, - ) - - out = out.reshape(input_shape) - stat_sh = _stat_shape(input_shape, len(normalized_shape)) - mean = mean.view(stat_sh) - rstd = rstd.view(stat_sh) - return out, mean, rstd - + x = input.reshape(M, N).contiguous() + dout = grad_out.reshape(M, N).contiguous() + rstd_flat = rstd.reshape(M).contiguous() -def _oink_native_layer_norm_backward( - grad_out: torch.Tensor, - input: torch.Tensor, - normalized_shape: List[int], - mean: torch.Tensor, - rstd: torch.Tensor, - weight: Optional[torch.Tensor], - bias: Optional[torch.Tensor], - output_mask: List[bool], -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if not _is_supported(input): - return _call_fallback( - "native_layer_norm_backward", - grad_out, input, normalized_shape, mean, rstd, weight, bias, - output_mask, - ) - - N = math.prod(normalized_shape) - M = input.numel() // N - - x = _reshape_2d(input, M, N) - dout = _reshape_2d(grad_out, M, N) - mean_flat = _flatten_1d(mean, M) - rstd_flat = _flatten_1d(rstd, M) + from kernelagent_oink.blackwell.rmsnorm import rmsnorm_backward - if _TRACE_ENABLED: - _trace_call("native_layer_norm_backward") - mod = _get_mod("layernorm") - dx, dw, db = mod.layernorm_backward( - dout, x, weight, rstd_flat, mean_flat, bias=bias, + dx, dw, _dbias, _dres = rmsnorm_backward( + x, weight, dout, rstd_flat, + dresidual_out=None, has_bias=False, has_residual=False, ) - dx = dx.reshape(input.shape) if dx is not None else torch.zeros_like(input) + dx = dx.reshape(input.shape) if not output_mask[0]: dx = torch.zeros_like(input) if not output_mask[1] or dw is None: dw = torch.zeros_like(weight) if weight is not None else torch.empty(0) - if not output_mask[2] or db is None: - db = torch.zeros_like(bias) if bias is not None else torch.empty(0) - - return dx, dw, db - - -def _oink_softmax( - self: torch.Tensor, - dim: int, - half_to_float: bool, -) -> torch.Tensor: - # Oink's softmax only handles the last dimension on 2D inputs. - # Fall back for other dims or when half_to_float is requested. 
- ndim = self.ndim - actual_dim = dim if dim >= 0 else dim + ndim - - if not _is_supported(self) or actual_dim != ndim - 1 or half_to_float: - return _call_fallback("_softmax", self, dim, half_to_float) - - input_shape = self.shape - N = input_shape[-1] - M = self.numel() // N - - x = _reshape_2d_checked(self, M, N) - if x is None: - return _call_fallback("_softmax", self, dim, half_to_float) - - if _TRACE_ENABLED: - _trace_call("_softmax") - mod = _get_mod("softmax") - y = mod.softmax_forward(x) - - return y.reshape(input_shape) - - -def _oink_softmax_backward( - grad_output: torch.Tensor, - output: torch.Tensor, - dim: int, - input_dtype: torch.dtype, -) -> torch.Tensor: - ndim = output.ndim - actual_dim = dim if dim >= 0 else dim + ndim - - if ( - not _is_supported(output) - or actual_dim != ndim - 1 - or input_dtype != output.dtype # half_to_float case - ): - return _call_fallback( - "_softmax_backward_data", grad_output, output, dim, input_dtype - ) - - input_shape = output.shape - N = input_shape[-1] - M = output.numel() // N - dy = _reshape_2d(grad_output, M, N) - y = _reshape_2d(output, M, N) - - if _TRACE_ENABLED: - _trace_call("_softmax_backward_data") - mod = _get_mod("softmax") - dx = mod.softmax_backward(dy, y) - - return dx.reshape(input_shape) + return dx, dw +# ========================================================================= +# Registration +# ========================================================================= _ATEN_LIB: torch.library.Library | None = None -# Mapping: (aten_op_name, impl_function) -_OVERRIDES = [ - ("_fused_rms_norm", _oink_fused_rms_norm), - ("_fused_rms_norm_backward", _oink_fused_rms_norm_backward), - ("native_layer_norm", _oink_native_layer_norm), - ("native_layer_norm_backward", _oink_native_layer_norm_backward), - ("_softmax", _oink_softmax), - ("_softmax_backward_data", _oink_softmax_backward), -] - def override_all_aten_kernels() -> None: - """Patch all supported aten ops on the CUDA dispatch key to use Oink's - SM100 CuTeDSL kernels. + """Override ``aten::_fused_rms_norm`` on CUDA with oink's RMSNorm. - Idempotent — safe to call multiple times. Captures the original CUDA - kernels before overriding so that unsupported inputs (wrong dtype, older - GPU) fall back transparently. + Uses ``with_keyset=True`` (quack PR pattern) so the override receives + ``DispatchKeySet`` and can call the original kernel via ``call_boxed`` + for unsupported inputs — no Python wrapper overhead on the fallback path. """ global _ATEN_LIB if _ATEN_LIB is not None: return - # Capture original kernels *before* we overwrite them. 
- for op_name, _ in _OVERRIDES: - _capture_fallback(op_name) - - lib = torch.library.Library("aten", "IMPL") - registered = [] + fwd_fallback = torch.library.get_kernel("aten::_fused_rms_norm", "CUDA") + bwd_fallback = torch.library.get_kernel( + "aten::_fused_rms_norm_backward", "CUDA" + ) - for op_name, impl_fn in _OVERRIDES: - try: - lib.impl(op_name, impl_fn, "CUDA") - registered.append(op_name) - except Exception as e: - logger.warning("Oink: could not override aten::%s: %s", op_name, e) + fwd_impl = functools.partial( + _fused_rms_norm_impl, fallback_kernel=fwd_fallback + ) + bwd_impl = functools.partial( + _fused_rms_norm_backward_impl, fallback_kernel=bwd_fallback + ) + lib = torch.library.Library("aten", "IMPL") + lib.impl("_fused_rms_norm", fwd_impl, "CUDA", with_keyset=True) + lib.impl("_fused_rms_norm_backward", bwd_impl, "CUDA", with_keyset=True) _ATEN_LIB = lib - logger.info("Oink: overrode %d aten ops on CUDA: %s", len(registered), registered) + logger.info("Oink: overrode aten::_fused_rms_norm on CUDA (with_keyset)") def restore_all_aten_kernels() -> None: - """Remove all Oink overrides and restore PyTorch's native CUDA kernels.""" + """Remove the override and restore PyTorch's native CUDA RMSNorm.""" global _ATEN_LIB if _ATEN_LIB is None: return _ATEN_LIB = None - logger.info("Oink: restored all aten ops to PyTorch defaults") + logger.info("Oink: restored aten::_fused_rms_norm to PyTorch default") -# Keep the old single-op API for backward compatibility. +# Backward-compatible aliases. override_aten_rmsnorm = override_all_aten_kernels restore_aten_rmsnorm = restore_all_aten_kernels diff --git a/oink/src/kernelagent_oink/blackwell/rmsnorm.py b/oink/src/kernelagent_oink/blackwell/rmsnorm.py index ddfe527..ebacbe9 100644 --- a/oink/src/kernelagent_oink/blackwell/rmsnorm.py +++ b/oink/src/kernelagent_oink/blackwell/rmsnorm.py @@ -3222,6 +3222,114 @@ def _rmsnorm_forward_ptr_into( ) +# --------------------------------------------------------------------------- +# Pre-compiled kernel API (quack-style _compile_rmsnorm_fwd) +# --------------------------------------------------------------------------- + +_COMPILE_FWD_CACHE: dict[tuple, object] = {} + + +def _make_fake_tensor( + dtype: type, shape: tuple, divisibility: int = 1, leading_dim: int = -1 +): + """Create a symbolic tensor for cute.compile (same as quack's fake_tensor).""" + if leading_dim < 0: + leading_dim = len(shape) + leading_dim + if dtype is None: + return None + stride = tuple( + cute.sym_int64(divisibility=divisibility) if i != leading_dim else 1 + for i in range(len(shape)) + ) + return cute.runtime.make_fake_tensor( + dtype, shape, stride=stride, + assumed_align=divisibility * dtype.width // 8, + ) + + +def _compile_rmsnorm_fwd( + dtype: type, + out_dtype: type | None, + bias_dtype: type | None, + weight_dtype: type | None, + residual_dtype: type | None, + residual_out_dtype: type | None, + N: int, + store_rstd: bool, + has_residual: bool, + has_bias: bool, +) -> object: + """Pre-compile the RMSNorm forward kernel and return a callable. + + Returns a TVM-FFI compiled callable with signature:: + + kernel(x, weight, bias, residual, out, residual_out, rstd, mean, eps) + + where each argument is a torch.Tensor (or None) and eps is a float. + DLPack tensor conversion happens at C++ level — zero Python overhead + per call. This matches the quack ``_compile_rmsnorm_fwd`` API. 
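+
+    Example (a sketch — dtypes are cutlass/CuTeDSL types and the values
+    shown are illustrative, not a tested configuration)::
+
+        fwd = _compile_rmsnorm_fwd(
+            cutlass.BFloat16,    # x dtype
+            cutlass.BFloat16,    # out dtype
+            None,                # no bias
+            cutlass.Float32,     # weight dtype
+            None,                # no residual
+            None,                # no residual_out
+            N=8192, store_rstd=True,
+            has_residual=False, has_bias=False,
+        )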
+ """ + key = ( + dtype, out_dtype, bias_dtype, weight_dtype, + residual_dtype, residual_out_dtype, N, + store_rstd, has_residual, has_bias, + ) + cached = _COMPILE_FWD_CACHE.get(key) + if cached is not None: + return cached + + # Resolve schedule once. + direct_gmem = _direct_gmem_from_policy( + default=bool(dtype.width == 16 and N in {128, 4096, 6144, 7168, 8192}) + ) + use_async = not direct_gmem + copy_bits = _copy_bits_from_policy(default=128, can_use_256=False) + stage = 1 + + op = RMSNormSM100( + N, + dtype, + stage=stage, + copy_bits=int(copy_bits), + use_async=bool(use_async), + direct_gmem=bool(direct_gmem), + ) + + # Build symbolic tensors for compilation (same as quack's fake_tensor). + batch_sym = cute.sym_int() + all_dtypes = [dtype, out_dtype, residual_dtype, weight_dtype, + bias_dtype, residual_out_dtype] + import math as _math + div = _math.gcd(N, *(128 // dt.width for dt in all_dtypes if dt is not None)) + + x_fake = _make_fake_tensor(dtype, (batch_sym, N), div) + out_fake = _make_fake_tensor(out_dtype or dtype, (batch_sym, N), div) + res_fake = _make_fake_tensor(residual_dtype, (batch_sym, N), div) if has_residual else None + res_out_fake = _make_fake_tensor(residual_out_dtype, (batch_sym, N), div) if residual_out_dtype else None + w_fake = _make_fake_tensor(weight_dtype, (N,), div) if weight_dtype else None + b_fake = _make_fake_tensor(bias_dtype, (N,), div) if bias_dtype else None + rstd_fake = _make_fake_tensor(cutlass.Float32, (batch_sym,)) if store_rstd else None + + # Compile with TVM FFI — the returned callable accepts torch tensors + # directly via DLPack at C++ level. + compiled = cute.compile( + op, + x_fake, + w_fake, + b_fake, + res_fake, + out_fake, + res_out_fake, + rstd_fake, + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + Float32(0), # eps placeholder + options="--enable-tvm-ffi", + ) + + _COMPILE_FWD_CACHE[key] = compiled + return compiled + + def _fused_add_rmsnorm_forward_ptr_inplace( x: Tensor, residual: Tensor, From 9b44197e4a1dfcc78185387886e9bb21d2bbe804 Mon Sep 17 00:00:00 2001 From: Jiannan Wang Date: Tue, 21 Apr 2026 15:51:02 -0700 Subject: [PATCH 6/6] Add register_all_kernels API and aten RMSNorm override with benchmarks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds kernelagent_oink.register_all_kernels(force=True) which overrides aten::_fused_rms_norm and aten::_fused_rms_norm_backward at the CUDA dispatch key. Follows the quack PR pattern (with_keyset=True, call_boxed fallback). Calls rmsnorm_forward/rmsnorm_backward directly for all kernel optimizations (ptr fast-launch, atomic dW, _reduce_partial_sum). No changes to rmsnorm.py — the override is purely in the registration layer. Backward returns None for masked outputs, matching native behavior. unregister_all_kernels() allows clean restore and re-registration. Includes benchmark scripts: - benchmark_rmsnorm_oink_vs_aten.py: kernel-level (do_bench_triton) - benchmark_rmsnorm_dispatch.py: dispatch-level (aten API, aten vs quack vs oink) - benchmark_cudagraph_all.py: CUDA graph (zero Python overhead) - benchmark_quack_overhead.py: quack PR overhead analysis GB200 results (bf16, CUDA graph, forward): (1, 8192): oink 2.54x faster than aten (4096, 8192): oink 1.28x faster than aten (65536, 8192): oink 1.18x faster than aten Oink matches or beats quack at all shapes (O/Q >= 1.00x) Without CUDA graph, Python dispatch overhead from rmsnorm_forward (~0.08ms/call) dominates at small shapes. 
With CUDA graph or torch.compile, the overhead is eliminated and the kernel advantage shows through. --- .../benchmark/benchmark_cudagraph_all.py | 105 +++++ .../benchmark/benchmark_quack_overhead.py | 103 +++++ .../benchmark/benchmark_rmsnorm_dispatch.py | 104 +++++ .../benchmark/run_benchmark_cudagraph_all.sh | 62 +++ .../benchmark/run_benchmark_quack_overhead.sh | 112 +++++ .../run_benchmark_rmsnorm_dispatch.sh | 83 ++++ oink/src/kernelagent_oink/__init__.py | 122 +++--- oink/src/kernelagent_oink/aten_override.py | 159 ++++---- .../src/kernelagent_oink/blackwell/rmsnorm.py | 108 ----- oink/tests/test_aten_override.py | 382 ++++-------------- 10 files changed, 793 insertions(+), 547 deletions(-) create mode 100644 oink/benchmarks/benchmark/benchmark_cudagraph_all.py create mode 100644 oink/benchmarks/benchmark/benchmark_quack_overhead.py create mode 100644 oink/benchmarks/benchmark/benchmark_rmsnorm_dispatch.py create mode 100755 oink/benchmarks/benchmark/run_benchmark_cudagraph_all.sh create mode 100755 oink/benchmarks/benchmark/run_benchmark_quack_overhead.sh create mode 100755 oink/benchmarks/benchmark/run_benchmark_rmsnorm_dispatch.sh diff --git a/oink/benchmarks/benchmark/benchmark_cudagraph_all.py b/oink/benchmarks/benchmark/benchmark_cudagraph_all.py new file mode 100644 index 0000000..5532641 --- /dev/null +++ b/oink/benchmarks/benchmark/benchmark_cudagraph_all.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +CUDA graph benchmark: aten vs quack vs oink RMSNorm. + +All calls go through torch.ops.aten._fused_rms_norm. CUDA graphs +eliminate Python dispatch overhead, isolating pure kernel performance. 
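+
+The capture/replay pattern used by ``bench_cudagraph`` below, in brief (a
+sketch of the same calls made in this file)::
+
+    g = torch.cuda.CUDAGraph()
+    with torch.cuda.graph(g):
+        fn()                      # capture one call into the graph
+    do_bench(lambda: g.replay())  # time graph replays only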
+ +Usage:: + + bash oink/benchmarks/benchmark/run_benchmark_cudagraph_all.sh +""" + +from __future__ import annotations + +import argparse +import json +import os + +os.environ.setdefault("TORCH_NATIVE_SKIP_VERSION_CHECK", "1") + +import torch +from triton.testing import do_bench + +SHAPES = [ + (1, 4096), + (1, 8192), + (32, 4096), + (32, 8192), + (256, 4096), + (256, 8192), + (1024, 4096), + (1024, 8192), + (4096, 4096), + (4096, 8192), + (16384, 4096), + (16384, 8192), + (65536, 4096), + (65536, 8192), +] +DTYPE = torch.bfloat16 + + +def bench_cudagraph(fn, warmup=50, rep=200): + """Capture fn into a CUDA graph, then benchmark replay.""" + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + fn() + torch.cuda.synchronize() + + return do_bench(lambda: g.replay(), warmup=10, rep=rep, return_mode="median") + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--mode", choices=["aten", "quack", "oink"], required=True) + args = p.parse_args() + + if args.mode == "oink": + import kernelagent_oink + kernelagent_oink.register_all_kernels(force=True) + + # Warm up + for M, N in SHAPES: + x = torch.randn(M, N, dtype=DTYPE, device="cuda") + w = torch.randn(N, dtype=DTYPE, device="cuda") + torch.ops.aten._fused_rms_norm(x, [N], w, 1e-5) + torch.cuda.synchronize() + + results = {} + for M, N in SHAPES: + x = torch.randn(M, N, dtype=DTYPE, device="cuda") + w = torch.randn(N, dtype=DTYPE, device="cuda") + + def fn_fwd(x=x, w=w, N=N): + return torch.ops.aten._fused_rms_norm(x, [N], w, 1e-5) + + try: + fwd_ms = bench_cudagraph(fn_fwd) + except Exception: + fwd_ms = -1.0 + + results[f"{M}x{N}"] = {"fwd": fwd_ms} + + print(json.dumps({"mode": args.mode, "results": results})) + + +if __name__ == "__main__": + main() diff --git a/oink/benchmarks/benchmark/benchmark_quack_overhead.py b/oink/benchmarks/benchmark/benchmark_quack_overhead.py new file mode 100644 index 0000000..c7ef501 --- /dev/null +++ b/oink/benchmarks/benchmark/benchmark_quack_overhead.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Overhead analysis for the quack RMSNorm PR (pytorch#178326). + +Measures aten vs quack through the same ``torch.ops.aten._fused_rms_norm`` +API at various shapes. Quack is registered via ``torch._native`` (the PR +pattern). Run with ``--mode=aten`` or ``--mode=quack`` in separate processes +to avoid cross-contamination. + +Usage:: + + bash oink/benchmarks/benchmark/run_benchmark_quack_overhead.sh +""" + +from __future__ import annotations + +import argparse +import json +import os + +os.environ.setdefault("TORCH_NATIVE_SKIP_VERSION_CHECK", "1") + +import torch +from triton.testing import do_bench + +# Comprehensive shape grid: small → large M, production N values. 
+SHAPES = [ + # Small M (dispatch overhead dominates) + (1, 4096), + (1, 8192), + (32, 4096), + (32, 8192), + (256, 4096), + (256, 8192), + # Medium M (crossover region) + (1024, 4096), + (1024, 8192), + (4096, 4096), + (4096, 8192), + # Large M (kernel compute dominates) + (16384, 4096), + (16384, 8192), + (65536, 4096), + (65536, 8192), +] +DTYPE = torch.bfloat16 + + +def bench(fn, warmup=50, rep=200): + return do_bench(fn, warmup=warmup, rep=rep, return_mode="median") + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--mode", choices=["aten", "quack"], required=True) + args = p.parse_args() + + # Warm up (triggers JIT compilation for quack). + for M, N in SHAPES: + x = torch.randn(M, N, dtype=DTYPE, device="cuda") + w = torch.randn(N, dtype=DTYPE, device="cuda") + torch.ops.aten._fused_rms_norm(x, [N], w, 1e-5) + torch.cuda.synchronize() + + results = {} + for M, N in SHAPES: + x = torch.randn(M, N, dtype=DTYPE, device="cuda", requires_grad=True) + w = torch.randn(N, dtype=DTYPE, device="cuda", requires_grad=True) + grad = torch.randn(M, N, dtype=DTYPE, device="cuda") + + def fn_fwd(x=x, w=w, N=N): + return torch.ops.aten._fused_rms_norm(x, [N], w, 1e-5) + + fwd_ms = bench(fn_fwd) + + x_ = x.detach().requires_grad_(True) + w_ = w.detach().requires_grad_(True) + + def fn_fwdbwd(x_=x_, w_=w_, N=N, grad=grad): + y, _ = torch.ops.aten._fused_rms_norm(x_, [N], w_, 1e-5) + y.backward(grad) + + fwdbwd_ms = bench(fn_fwdbwd) + results[f"{M}x{N}"] = {"fwd": fwd_ms, "fwdbwd": fwdbwd_ms} + + print(json.dumps({"mode": args.mode, "results": results})) + + +if __name__ == "__main__": + main() diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_dispatch.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_dispatch.py new file mode 100644 index 0000000..1f87209 --- /dev/null +++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_dispatch.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Benchmark aten vs quack vs oink RMSNorm through the same PyTorch API. + +All three are called via ``torch.ops.aten._fused_rms_norm``. Quack is +registered via ``torch._native`` (requires the quack PR in pytorch). +Oink is registered via ``kernelagent_oink.register_all_kernels``. +Aten is the unoverridden baseline. + +This script must be invoked three times with ``--mode={aten,quack,oink}`` +by the companion ``run_benchmark_rmsnorm_dispatch.sh`` script, which +swaps ``torch._native/ops/norm/__init__.py`` between runs. 
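+
+The swap itself is a one-line rewrite of that ``__init__.py`` (sketch;
+these are the same commands the companion run_benchmark_*.sh scripts use)::
+
+    # enable quack:
+    echo "from . import rmsnorm_impl  # noqa: F401" > "$NORM_DIR/__init__.py"
+    # disable quack (aten / oink runs):
+    echo "" > "$NORM_DIR/__init__.py"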
+ +Usage:: + + bash oink/benchmarks/benchmark/run_benchmark_rmsnorm_dispatch.sh +""" + +from __future__ import annotations + +import argparse +import json +import os + +os.environ.setdefault("TORCH_NATIVE_SKIP_VERSION_CHECK", "1") + +import torch +from triton.testing import do_bench + +SHAPES = [ + (4096, 4096), + (4096, 8192), + (16384, 4096), + (16384, 8192), + (65536, 4096), + (65536, 8192), +] +DTYPE = torch.bfloat16 + + +def bench(fn, warmup=50, rep=200): + return do_bench(fn, warmup=warmup, rep=rep, return_mode="median") + + +def main(): + p = argparse.ArgumentParser( + description="Benchmark aten/quack/oink RMSNorm through aten API." + ) + p.add_argument("--mode", choices=["aten", "quack", "oink"], required=True) + args = p.parse_args() + + if args.mode == "oink": + import kernelagent_oink + + kernelagent_oink.register_all_kernels(force=True) + + # Warm up (triggers JIT compilation for quack/oink). + for M, N in SHAPES: + x = torch.randn(M, N, dtype=DTYPE, device="cuda") + w = torch.randn(N, dtype=DTYPE, device="cuda") + torch.ops.aten._fused_rms_norm(x, [N], w, 1e-5) + torch.cuda.synchronize() + + results = {} + for M, N in SHAPES: + x = torch.randn(M, N, dtype=DTYPE, device="cuda", requires_grad=True) + w = torch.randn(N, dtype=DTYPE, device="cuda", requires_grad=True) + grad = torch.randn(M, N, dtype=DTYPE, device="cuda") + + # Forward only. + def fn_fwd(x=x, w=w, N=N): + return torch.ops.aten._fused_rms_norm(x, [N], w, 1e-5) + + fwd_ms = bench(fn_fwd) + + # Forward + backward. + x_ = x.detach().requires_grad_(True) + w_ = w.detach().requires_grad_(True) + + def fn_fwdbwd(x_=x_, w_=w_, N=N, grad=grad): + y, _ = torch.ops.aten._fused_rms_norm(x_, [N], w_, 1e-5) + y.backward(grad) + + fwdbwd_ms = bench(fn_fwdbwd) + results[f"{M}x{N}"] = {"fwd": fwd_ms, "fwdbwd": fwdbwd_ms} + + print(json.dumps({"mode": args.mode, "results": results})) + + +if __name__ == "__main__": + main() diff --git a/oink/benchmarks/benchmark/run_benchmark_cudagraph_all.sh b/oink/benchmarks/benchmark/run_benchmark_cudagraph_all.sh new file mode 100755 index 0000000..ec122ff --- /dev/null +++ b/oink/benchmarks/benchmark/run_benchmark_cudagraph_all.sh @@ -0,0 +1,62 @@ +#!/bin/bash +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +NORM_DIR="${TORCH_NATIVE_NORM_DIR:-$(python -c 'import torch, pathlib; print(pathlib.Path(torch.__file__).parent / "_native/ops/norm")' 2>/dev/null)}" + +RESULTS_DIR="${RESULTS_DIR:-/tmp}" + +echo "CUDA Graph Benchmark: Aten vs Quack vs Oink RMSNorm" +echo "====================================================" +echo "" + +# Aten +echo "Running aten..." +echo "" > "$NORM_DIR/__init__.py" +TORCH_NATIVE_SKIP_VERSION_CHECK=1 python "$SCRIPT_DIR/benchmark_cudagraph_all.py" \ + --mode=aten > "$RESULTS_DIR/cudagraph_all_aten.json" 2>/dev/null + +# Quack +echo "Running quack..." +echo "from . import rmsnorm_impl # noqa: F401" > "$NORM_DIR/__init__.py" +TORCH_NATIVE_SKIP_VERSION_CHECK=1 python "$SCRIPT_DIR/benchmark_cudagraph_all.py" \ + --mode=quack > "$RESULTS_DIR/cudagraph_all_quack.json" 2>/dev/null + +# Oink +echo "Running oink..." +echo "" > "$NORM_DIR/__init__.py" +TORCH_NATIVE_SKIP_VERSION_CHECK=1 python "$SCRIPT_DIR/benchmark_cudagraph_all.py" \ + --mode=oink > "$RESULTS_DIR/cudagraph_all_oink.json" 2>/dev/null + +# Restore +echo "from . 
import rmsnorm_impl # noqa: F401" > "$NORM_DIR/__init__.py" + +python3 -c " +import json + +aten = json.loads(open('$RESULTS_DIR/cudagraph_all_aten.json').read())['results'] +quack = json.loads(open('$RESULTS_DIR/cudagraph_all_quack.json').read())['results'] +oink = json.loads(open('$RESULTS_DIR/cudagraph_all_oink.json').read())['results'] + +print() +print('Forward (CUDA graph, bf16):') +print('┌──────────────────┬───────────┬───────────┬───────────┬─────────┬─────────┬─────────┐') +print('│ Shape │ Aten (ms) │ Quack (ms)│ Oink (ms) │ Q vs A │ O vs A │ O vs Q │') +print('├──────────────────┼───────────┼───────────┼───────────┼─────────┼─────────┼─────────┤') +for shape in aten: + M, N = shape.split('x') + a = aten[shape]['fwd'] + q = quack[shape]['fwd'] + o = oink[shape]['fwd'] + def fmt(v): + return f'{v:9.4f}' if v > 0 else ' FAIL' + def ratio(num, den): + if den <= 0 or num <= 0: + return ' N/A ' + return f'{num/den:7.2f}x' + print(f'│ ({M:>5s}, {N:>5s}) │ {fmt(a)} │ {fmt(q)} │ {fmt(o)} │ {ratio(a,q)} │ {ratio(a,o)} │ {ratio(q,o)} │') +print('└──────────────────┴───────────┴───────────┴───────────┴─────────┴─────────┴─────────┘') +" + +echo "" +echo "Done." diff --git a/oink/benchmarks/benchmark/run_benchmark_quack_overhead.sh b/oink/benchmarks/benchmark/run_benchmark_quack_overhead.sh new file mode 100755 index 0000000..2c2baca --- /dev/null +++ b/oink/benchmarks/benchmark/run_benchmark_quack_overhead.sh @@ -0,0 +1,112 @@ +#!/bin/bash +# Overhead analysis for the quack RMSNorm PR (pytorch#178326). +# +# Compares aten baseline vs quack override through the same +# torch.ops.aten._fused_rms_norm API at various shapes. +# +# Usage: +# cd KernelAgent +# conda activate nanoGPT +# bash oink/benchmarks/benchmark/run_benchmark_quack_overhead.sh + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +NORM_DIR="${TORCH_NATIVE_NORM_DIR:-$(python -c 'import torch, pathlib; print(pathlib.Path(torch.__file__).parent / "_native/ops/norm")' 2>/dev/null)}" + +if [ -z "$NORM_DIR" ] || [ ! -d "$NORM_DIR" ]; then + echo "ERROR: torch._native/ops/norm/ not found." + exit 1 +fi + +RESULTS_DIR="${RESULTS_DIR:-/tmp}" + +echo "Quack RMSNorm PR Overhead Analysis" +echo "===================================" +echo "norm dir: $NORM_DIR" +echo "" + +# --- Aten baseline (no override) --- +echo "Running aten baseline..." +echo "" > "$NORM_DIR/__init__.py" +TORCH_NATIVE_SKIP_VERSION_CHECK=1 python "$SCRIPT_DIR/benchmark_quack_overhead.py" \ + --mode=aten > "$RESULTS_DIR/quack_overhead_aten.json" 2>/dev/null + +# --- Quack via torch._native --- +echo "Running quack override..." +echo "from . import rmsnorm_impl # noqa: F401" > "$NORM_DIR/__init__.py" +TORCH_NATIVE_SKIP_VERSION_CHECK=1 python "$SCRIPT_DIR/benchmark_quack_overhead.py" \ + --mode=quack > "$RESULTS_DIR/quack_overhead_quack.json" 2>/dev/null + +# --- Restore --- +echo "from . 
import rmsnorm_impl  # noqa: F401" > "$NORM_DIR/__init__.py"
+
+# --- Print report ---
+python3 -c "
+import json
+
+aten = json.loads(open('$RESULTS_DIR/quack_overhead_aten.json').read())['results']
+quack = json.loads(open('$RESULTS_DIR/quack_overhead_quack.json').read())['results']
+
+def print_table(title, key):
+    print(title)
+    print('┌──────────────────┬───────────┬───────────┬──────────┬────────────┐')
+    print('│ Shape            │ Aten (ms) │ Quack (ms)│  Q vs A  │ Overhead   │')
+    print('├──────────────────┼───────────┼───────────┼──────────┼────────────┤')
+    for shape in aten:
+        M, N = shape.split('x')
+        a = aten[shape][key]
+        q = quack[shape][key]
+        ratio = a / q
+        overhead_ms = q - a
+        print(f'│ ({M:>5s}, {N:>5s}) │ {a:>9.4f} │ {q:>9.4f} │ {ratio:>7.2f}x │ {overhead_ms:>+8.4f}ms│')
+    print('└──────────────────┴───────────┴───────────┴──────────┴────────────┘')
+    print()
+
+print()
+print_table('Forward only:', 'fwd')
+print_table('Forward + Backward:', 'fwdbwd')
+
+# Summary
+print('Summary:')
+print('--------')
+fwd_crossover = None
+bwd_crossover = None
+for shape in aten:
+    M, N = shape.split('x')
+    M = int(M)
+    a_fwd = aten[shape]['fwd']
+    q_fwd = quack[shape]['fwd']
+    a_bwd = aten[shape]['fwdbwd']
+    q_bwd = quack[shape]['fwdbwd']
+    if fwd_crossover is None and a_fwd / q_fwd >= 1.0:
+        fwd_crossover = (M, int(N))
+    if bwd_crossover is None and a_bwd / q_bwd >= 1.0:
+        bwd_crossover = (M, int(N))
+
+if fwd_crossover:
+    print(f'  Forward crossover (quack >= aten): M={fwd_crossover[0]}, N={fwd_crossover[1]}')
+else:
+    print('  Forward: quack is slower than aten at all tested shapes')
+if bwd_crossover:
+    print(f'  Fwd+Bwd crossover (quack >= aten): M={bwd_crossover[0]}, N={bwd_crossover[1]}')
+else:
+    print('  Fwd+Bwd: quack is slower than aten at all tested shapes')
+
+# Overhead analysis
+print()
+print('Overhead analysis:')
+small_fwd = [quack[s]['fwd'] - aten[s]['fwd'] for s in list(aten)[:6]]
+small_bwd = [quack[s]['fwdbwd'] - aten[s]['fwdbwd'] for s in list(aten)[:6]]
+avg_fwd_overhead = sum(small_fwd) / len(small_fwd)
+avg_bwd_overhead = sum(small_bwd) / len(small_bwd)
+print(f'  Avg fwd overhead at small M (1-256):     {avg_fwd_overhead:+.4f} ms/call')
+print(f'  Avg fwd+bwd overhead at small M (1-256): {avg_bwd_overhead:+.4f} ms/call')
+print('  This overhead is from Python dispatch through torch._native:')
+print('    _fused_rms_norm_impl → quack_rmsnorm_fwd → _compile_rmsnorm_fwd → kernel')
+print('  At large M, the faster CuTeDSL kernel overcomes this overhead.')
+"
+
+echo ""
+echo "Done."
diff --git a/oink/benchmarks/benchmark/run_benchmark_rmsnorm_dispatch.sh b/oink/benchmarks/benchmark/run_benchmark_rmsnorm_dispatch.sh
new file mode 100755
index 0000000..95a632c
--- /dev/null
+++ b/oink/benchmarks/benchmark/run_benchmark_rmsnorm_dispatch.sh
@@ -0,0 +1,83 @@
+#!/bin/bash
+# Benchmark aten vs quack vs oink RMSNorm through the same aten API.
+#
+# Quack is registered via torch._native (requires the quack PR in PyTorch).
+# Oink is registered via kernelagent_oink.register_all_kernels.
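+# Aten is the unoverridden baseline: its run empties
+# torch._native/ops/norm/__init__.py so no override is active.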
+# +# Usage: +# bash oink/benchmarks/benchmark/run_benchmark_rmsnorm_dispatch.sh +# +# Prerequisites: +# - conda env "nanoGPT" with torch, quack-kernels==0.3.7, kernelagent-oink +# - torch._native infrastructure installed in the torch package +# - `import torch._native` at end of torch/__init__.py + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +NORM_DIR="${TORCH_NATIVE_NORM_DIR:-$(python -c 'import torch, pathlib; print(pathlib.Path(torch.__file__).parent / "_native/ops/norm")' 2>/dev/null)}" + +if [ -z "$NORM_DIR" ] || [ ! -d "$NORM_DIR" ]; then + echo "ERROR: torch._native/ops/norm/ not found. Set TORCH_NATIVE_NORM_DIR or install torch._native." + exit 1 +fi + +RESULTS_DIR="${RESULTS_DIR:-/tmp}" + +echo "Using norm dir: $NORM_DIR" +echo "Results dir: $RESULTS_DIR" +echo "" + +# --- Aten baseline (no override) --- +echo "Benchmarking aten..." +echo "" > "$NORM_DIR/__init__.py" +TORCH_NATIVE_SKIP_VERSION_CHECK=1 python "$SCRIPT_DIR/benchmark_rmsnorm_dispatch.py" \ + --mode=aten > "$RESULTS_DIR/dispatch_aten.json" 2>/dev/null + +# --- Quack via torch._native --- +echo "Benchmarking quack..." +echo "from . import rmsnorm_impl # noqa: F401" > "$NORM_DIR/__init__.py" +TORCH_NATIVE_SKIP_VERSION_CHECK=1 python "$SCRIPT_DIR/benchmark_rmsnorm_dispatch.py" \ + --mode=quack > "$RESULTS_DIR/dispatch_quack.json" 2>/dev/null + +# --- Oink via aten_override (register_all_kernels) --- +echo "Benchmarking oink..." +echo "" > "$NORM_DIR/__init__.py" +TORCH_NATIVE_SKIP_VERSION_CHECK=1 python "$SCRIPT_DIR/benchmark_rmsnorm_dispatch.py" \ + --mode=oink > "$RESULTS_DIR/dispatch_oink.json" 2>/dev/null + +# --- Restore --- +echo "from . import rmsnorm_impl # noqa: F401" > "$NORM_DIR/__init__.py" + +# --- Print tables --- +python3 -c " +import json, sys + +aten = json.loads(open('$RESULTS_DIR/dispatch_aten.json').read())['results'] +quack = json.loads(open('$RESULTS_DIR/dispatch_quack.json').read())['results'] +oink = json.loads(open('$RESULTS_DIR/dispatch_oink.json').read())['results'] + +def table(title, key): + print(f'**{title}:**') + print('\`\`\`') + print('┌──────────────────┬───────────┬───────────┬───────────┬─────────┬─────────┬─────────┐') + print('│ Shape │ Aten (ms) │ Quack (ms)│ Oink (ms) │ Q vs A │ O vs A │ O vs Q │') + print('├──────────────────┼───────────┼───────────┼───────────┼─────────┼─────────┼─────────┤') + for shape in aten: + M, N = shape.split('x') + a = aten[shape][key] + q = quack[shape][key] + o = oink[shape][key] + qa = a / q + oa = a / o + oq = q / o + print(f'│ ({M:>5s}, {N:>5s}) │ {a:>9.3f} │ {q:>9.3f} │ {o:>9.3f} │ {qa:>6.2f}x │ {oa:>6.2f}x │ {oq:>6.2f}x │') + print('└──────────────────┴───────────┴───────────┴───────────┴─────────┴─────────┴─────────┘') + print('\`\`\`') + print() + +table('Forward', 'fwd') +table('Forward + Backward', 'fwdbwd') +" + +echo "Done." diff --git a/oink/src/kernelagent_oink/__init__.py b/oink/src/kernelagent_oink/__init__.py index 5449481..d61b60e 100644 --- a/oink/src/kernelagent_oink/__init__.py +++ b/oink/src/kernelagent_oink/__init__.py @@ -58,66 +58,72 @@ def _compute_cutedsl_arch(major: int, minor: int) -> str: return f"sm_{major}{minor}{suffix}" -def register(*, force: bool = False) -> None: - """Register Oink torch custom ops. - - - vLLM plugin mode (default): no-op unless `VLLM_USE_OINK_RMSNORM` is truthy. - - Standalone mode: pass `force=True` to register explicitly. +def _check_and_setup() -> bool: + """Check CUDA availability, SM >= 100, CuTeDSL deps, and set CUTE_DSL_ARCH. 
- This function must be safe to call multiple times and must not raise. vLLM - executes it in multiple processes (engine + workers). + Returns True if all checks pass, False otherwise. Does not raise. """ - global _OPS_REGISTERED - - if _OPS_REGISTERED: - return - - # Gate on the vLLM integration flag so installing the package does not - # change behavior unless explicitly enabled. For standalone usage (outside - # vLLM), callers can pass force=True to register the ops explicitly. - if not force and not _env_truthy("VLLM_USE_OINK_RMSNORM"): - return - try: import torch - except Exception as e: # pragma: no cover + except Exception as e: logger.debug("Oink plugin: torch import failed: %s", e) - return + return False try: if not torch.cuda.is_available(): - logger.debug("Oink plugin: torch.cuda.is_available() is False; skipping") - return + logger.debug("Oink plugin: CUDA not available; skipping") + return False device_index = _infer_cuda_device_index() major, minor = torch.cuda.get_device_capability(device_index) sm = 10 * int(major) + int(minor) if sm < 100: - return + return False - # Ensure required deps are importable before registering ops so that vLLM - # doesn't detect ops that would later fail at first use. try: import cutlass # noqa: F401 import cuda.bindings.driver as _cuda # noqa: F401 except Exception as e: logger.warning( - "Oink plugin: CuTeDSL deps missing; skipping op registration. " + "Oink plugin: CuTeDSL deps missing; skipping. " "Install `nvidia-cutlass-dsl` + `cuda-python`. Error: %s", e, ) - return + return False - # Ensure CuTeDSL sees a target arch early. If the user has already set it, - # respect their choice. os.environ.setdefault( "CUTE_DSL_ARCH", _compute_cutedsl_arch(int(major), int(minor)) ) + return True + except Exception as e: + logger.exception("Oink plugin: setup failed: %s", e) + return False + + +def register(*, force: bool = False) -> None: + """Register Oink torch custom ops (``torch.ops.oink.*``). + + This registers ``torch.ops.oink.rmsnorm`` and + ``torch.ops.oink.fused_add_rms_norm`` for use by vLLM's direct-call path. + It does NOT override aten ops — use :func:`register_all_kernels` for that. + + - vLLM plugin mode (default): no-op unless ``VLLM_USE_OINK_RMSNORM`` is truthy. + - Standalone mode: pass ``force=True`` to register explicitly. + """ + global _OPS_REGISTERED + + if _OPS_REGISTERED: + return + + if not force and not _env_truthy("VLLM_USE_OINK_RMSNORM"): + return - # Import registers the ops via torch.library.custom_op decorators. + if not _check_and_setup(): + return + + try: from .blackwell import oink_custom_ops # noqa: F401 - except Exception as e: # pragma: no cover - # Do not raise: vLLM plugin loader does not guard plugin execution. - logger.exception("Oink plugin: failed to register ops: %s", e) + except Exception as e: + logger.exception("Oink plugin: failed to register custom ops: %s", e) return _OPS_REGISTERED = True @@ -127,48 +133,48 @@ def register(*, force: bool = False) -> None: def register_all_kernels(*, force: bool = False) -> None: - """Register Oink custom ops *and* override PyTorch's native aten operators. + """Override aten ops with Oink's kernels. - This is the main entry point for redirecting standard PyTorch calls - (``F.rms_norm``, ``F.layer_norm``, ``F.softmax``, etc.) to Oink's SM100 - CuTeDSL kernels. It performs two steps: + Checks CUDA/SM100/deps, sets up the CuTeDSL environment, then overrides + ``aten::_fused_rms_norm`` and ``aten::_fused_rms_norm_backward`` on CUDA. - 1. 
Calls :func:`register` to define ``torch.ops.oink.rmsnorm`` and - ``torch.ops.oink.fused_add_rms_norm`` custom ops. - 2. Patches the following aten ops at the CUDA dispatch key so that - PyTorch's built-in operators transparently use the Oink kernels: - - - ``aten::_fused_rms_norm`` / ``aten::_fused_rms_norm_backward`` - - ``aten::native_layer_norm`` / ``aten::native_layer_norm_backward`` - - ``aten::_softmax`` / ``aten::_softmax_backward_data`` - - Unsupported inputs (wrong dtype, SM < 100) fall back to PyTorch's - original CUDA kernels automatically. + Does NOT register ``torch.ops.oink.*`` custom ops — use :func:`register` + separately if those are needed (e.g. for vLLM's direct-call path). Args: - force: If *True*, register regardless of the - ``VLLM_USE_OINK_RMSNORM`` environment variable. + force: If *True*, bypass the ``VLLM_USE_OINK_RMSNORM`` env gate. """ global _ALL_KERNELS_REGISTERED if _ALL_KERNELS_REGISTERED: return - # Step 1: register torch.ops.oink.* custom ops. - register(force=force) + if not force and not _env_truthy("VLLM_USE_OINK_RMSNORM"): + return - if not _OPS_REGISTERED: + if not _check_and_setup(): return - # Step 2: override aten ops on CUDA. try: - from .aten_override import override_all_aten_kernels + from .aten_override import override_all_kernels - override_all_aten_kernels() - except Exception as e: # pragma: no cover + override_all_kernels() + except Exception as e: logger.exception("Oink: failed to override aten ops: %s", e) return _ALL_KERNELS_REGISTERED = True -__all__ = ["register", "register_all_kernels"] +def unregister_all_kernels() -> None: + """Remove the aten override. Can be followed by :func:`register_all_kernels`.""" + global _ALL_KERNELS_REGISTERED + try: + from .aten_override import restore_all_kernels + + restore_all_kernels() + except Exception: + pass + _ALL_KERNELS_REGISTERED = False + + +__all__ = ["register", "register_all_kernels", "unregister_all_kernels"] diff --git a/oink/src/kernelagent_oink/aten_override.py b/oink/src/kernelagent_oink/aten_override.py index cfcdc95..1dc62f2 100644 --- a/oink/src/kernelagent_oink/aten_override.py +++ b/oink/src/kernelagent_oink/aten_override.py @@ -13,52 +13,46 @@ # limitations under the License. """ -Override ``aten::_fused_rms_norm`` with Oink's Blackwell CuTeDSL RMSNorm. +Override Aten kernels with Oink's Blackwell CuTeDSL Kernels. -Follows the quack PR pattern (``torch/_native/ops/norm/rmsnorm_impl.py``): -registers with ``with_keyset=True`` so the impl receives a -``DispatchKeySet`` and can call the captured fallback via ``call_boxed``. +Currently overrides: +- ``aten::_fused_rms_norm`` → ``rmsnorm_forward`` +- ``aten::_fused_rms_norm_backward`` → ``rmsnorm_backward`` -Usage:: - - import kernelagent_oink - kernelagent_oink.register_all_kernels(force=True) - y = torch.nn.functional.rms_norm(x, [N], weight, eps) # uses Oink +Follows the quack PR pattern: ``with_keyset=True``, fallback via ``call_boxed``. +Calls ``rmsnorm_forward`` / ``rmsnorm_backward`` directly to get all kernel +optimizations (ptr fast-launch, atomic dW, _reduce_partial_sum_fp32). """ from __future__ import annotations -import functools +import importlib import logging import math -import os +from functools import cache, partial from typing import List, Optional, Tuple import torch logger = logging.getLogger(__name__) + # --------------------------------------------------------------------------- -# Optional tracing — zero overhead when disabled (OINK_TRACE=1). 
+# Lazy imports (cached)
 # ---------------------------------------------------------------------------
 
-_TRACE_ENABLED: bool = os.environ.get("OINK_TRACE", "0").strip() in ("1", "true")
-_call_counts: dict[str, int] = {}
-
-
-def _trace_call(op_name: str) -> None:
-    n = _call_counts.get(op_name, 0) + 1
-    _call_counts[op_name] = n
-    if n == 1:
-        print(f"[OINK] {op_name} override called (first invocation)", flush=True)
+@cache
+def _oink_rmsnorm():
+    return importlib.import_module("kernelagent_oink.blackwell.rmsnorm")
 
 
 # ---------------------------------------------------------------------------
-# Device / dtype support check
+# Device support (cached)
 # ---------------------------------------------------------------------------
 
 
-@functools.cache
+@cache
 def _get_device_major(device: torch.device) -> int:
     major, _ = torch.cuda.get_device_capability(device)
     return major
@@ -66,10 +60,27 @@ def _get_device_major(device: torch.device) -> int:
 
 def _is_supported(input: torch.Tensor) -> bool:
     return input.dtype in (
-        torch.float16,
-        torch.bfloat16,
-        torch.float32,
-    ) and _get_device_major(input.device) in (9, 10)
+        torch.float16, torch.bfloat16, torch.float32,
+    ) and _get_device_major(input.device) >= 10
+
+
+# ---------------------------------------------------------------------------
+# Reshape helpers (match quack's norms.py)
+# ---------------------------------------------------------------------------
+
+
+def _reshape_2d(t: torch.Tensor, M: int, N: int) -> torch.Tensor:
+    if t.ndim == 2 and t.shape[0] == M and t.shape[1] == N and t.is_contiguous():
+        return t
+    return t.reshape(M, N).contiguous()
+
+
+def _flatten_rstd(t: torch.Tensor, M: int) -> torch.Tensor:
+    if t.ndim == 1 and t.shape[0] == M:
+        return t
+    if t.is_contiguous() and t.numel() == M:
+        return t.detach().view(M)
+    return t.reshape(M).contiguous()
 
 
 # =========================================================================
@@ -90,27 +101,24 @@ def _fused_rms_norm_impl(
         return fallback_kernel.call_boxed(
             dispatch_keys, input, normalized_shape, weight, eps
         )
-
     if eps is None:
         eps = 1e-6
 
-    if _TRACE_ENABLED:
-        _trace_call("_fused_rms_norm")
-
-    orig_shape = input.shape
+    input_shape = input.shape
     N = math.prod(normalized_shape)
     M = input.numel() // N
     x = input.reshape(M, N)
 
-    from kernelagent_oink.blackwell.rmsnorm import rmsnorm_forward
+    if weight is not None and weight.ndim != 1:
+        weight = weight.reshape(N)
 
-    y, rstd, _ = rmsnorm_forward(
+    y, rstd, _ = _oink_rmsnorm().rmsnorm_forward(
         x, weight=weight, bias=None, residual=None, eps=eps, store_rstd=True,
     )
 
-    y = y.reshape(orig_shape)
-    stat_shape = list(orig_shape[: -len(normalized_shape)]) + [1] * len(
+    y = y.reshape(input_shape)
+    stat_shape = list(input_shape[: -len(normalized_shape)]) + [1] * len(
         normalized_shape
     )
     rstd = rstd.view(stat_shape)
@@ -139,87 +147,66 @@ def _fused_rms_norm_backward_impl(
             grad_out, input, normalized_shape, rstd, weight, output_mask,
         )
 
-    if _TRACE_ENABLED:
-        _trace_call("_fused_rms_norm_backward")
-
     N = math.prod(normalized_shape)
     M = input.numel() // N
-    x = input.reshape(M, N).contiguous()
-    dout = grad_out.reshape(M, N).contiguous()
-    rstd_flat = rstd.reshape(M).contiguous()
+    x = _reshape_2d(input, M, N)
+    dout = _reshape_2d(grad_out, M, N)
+    rstd_flat = _flatten_rstd(rstd, M)
 
-    from kernelagent_oink.blackwell.rmsnorm import rmsnorm_backward
-
-    dx, dw, _dbias, _dres = rmsnorm_backward(
-        x, weight, dout, rstd_flat,
+    # Weight is passed unconditionally: grad_input depends on it even when
+    # grad_weight is masked out below.
+    dx, dw, _db, _dres = _oink_rmsnorm().rmsnorm_backward(
+        x, weight, dout, rstd_flat,
dresidual_out=None, has_bias=False, has_residual=False, ) - dx = dx.reshape(input.shape) + grad_input: torch.Tensor | None = dx.reshape(input.shape) + grad_weight: torch.Tensor | None = dw + # Match native _fused_rms_norm_backward: return None for masked outputs. if not output_mask[0]: - dx = torch.zeros_like(input) - if not output_mask[1] or dw is None: - dw = torch.zeros_like(weight) if weight is not None else torch.empty(0) + grad_input = None + if not output_mask[1]: + grad_weight = None - return dx, dw + return grad_input, grad_weight # ========================================================================= # Registration # ========================================================================= -_ATEN_LIB: torch.library.Library | None = None - +_OVERRIDE_LIB: torch.library.Library | None = None -def override_all_aten_kernels() -> None: - """Override ``aten::_fused_rms_norm`` on CUDA with oink's RMSNorm. - Uses ``with_keyset=True`` (quack PR pattern) so the override receives - ``DispatchKeySet`` and can call the original kernel via ``call_boxed`` - for unsupported inputs — no Python wrapper overhead on the fallback path. - """ - global _ATEN_LIB - if _ATEN_LIB is not None: +def override_all_kernels() -> None: + """Override Aten's kernels on CUDA with Oink's kernels.""" + global _OVERRIDE_LIB + if _OVERRIDE_LIB is not None: return fwd_fallback = torch.library.get_kernel("aten::_fused_rms_norm", "CUDA") - bwd_fallback = torch.library.get_kernel( - "aten::_fused_rms_norm_backward", "CUDA" - ) + bwd_fallback = torch.library.get_kernel("aten::_fused_rms_norm_backward", "CUDA") - fwd_impl = functools.partial( - _fused_rms_norm_impl, fallback_kernel=fwd_fallback - ) - bwd_impl = functools.partial( - _fused_rms_norm_backward_impl, fallback_kernel=bwd_fallback - ) + fwd_impl = partial(_fused_rms_norm_impl, fallback_kernel=fwd_fallback) + bwd_impl = partial(_fused_rms_norm_backward_impl, fallback_kernel=bwd_fallback) lib = torch.library.Library("aten", "IMPL") lib.impl("_fused_rms_norm", fwd_impl, "CUDA", with_keyset=True) lib.impl("_fused_rms_norm_backward", bwd_impl, "CUDA", with_keyset=True) - _ATEN_LIB = lib - logger.info("Oink: overrode aten::_fused_rms_norm on CUDA (with_keyset)") + _OVERRIDE_LIB = lib + logger.info("Oink: overrode aten::_fused_rms_norm on CUDA") -def restore_all_aten_kernels() -> None: - """Remove the override and restore PyTorch's native CUDA RMSNorm.""" - global _ATEN_LIB - if _ATEN_LIB is None: +def restore_all_kernels() -> None: + """Remove the override and restore PyTorch's native CUDA kernels.""" + global _OVERRIDE_LIB + if _OVERRIDE_LIB is None: return - _ATEN_LIB = None - logger.info("Oink: restored aten::_fused_rms_norm to PyTorch default") - - -# Backward-compatible aliases. 
-override_aten_rmsnorm = override_all_aten_kernels -restore_aten_rmsnorm = restore_all_aten_kernels + _OVERRIDE_LIB = None __all__ = [ - "override_all_aten_kernels", - "restore_all_aten_kernels", - "override_aten_rmsnorm", - "restore_aten_rmsnorm", + "override_all_kernels", + "restore_all_kernels", ] diff --git a/oink/src/kernelagent_oink/blackwell/rmsnorm.py b/oink/src/kernelagent_oink/blackwell/rmsnorm.py index ebacbe9..ddfe527 100644 --- a/oink/src/kernelagent_oink/blackwell/rmsnorm.py +++ b/oink/src/kernelagent_oink/blackwell/rmsnorm.py @@ -3222,114 +3222,6 @@ def _rmsnorm_forward_ptr_into( ) -# --------------------------------------------------------------------------- -# Pre-compiled kernel API (quack-style _compile_rmsnorm_fwd) -# --------------------------------------------------------------------------- - -_COMPILE_FWD_CACHE: dict[tuple, object] = {} - - -def _make_fake_tensor( - dtype: type, shape: tuple, divisibility: int = 1, leading_dim: int = -1 -): - """Create a symbolic tensor for cute.compile (same as quack's fake_tensor).""" - if leading_dim < 0: - leading_dim = len(shape) + leading_dim - if dtype is None: - return None - stride = tuple( - cute.sym_int64(divisibility=divisibility) if i != leading_dim else 1 - for i in range(len(shape)) - ) - return cute.runtime.make_fake_tensor( - dtype, shape, stride=stride, - assumed_align=divisibility * dtype.width // 8, - ) - - -def _compile_rmsnorm_fwd( - dtype: type, - out_dtype: type | None, - bias_dtype: type | None, - weight_dtype: type | None, - residual_dtype: type | None, - residual_out_dtype: type | None, - N: int, - store_rstd: bool, - has_residual: bool, - has_bias: bool, -) -> object: - """Pre-compile the RMSNorm forward kernel and return a callable. - - Returns a TVM-FFI compiled callable with signature:: - - kernel(x, weight, bias, residual, out, residual_out, rstd, mean, eps) - - where each argument is a torch.Tensor (or None) and eps is a float. - DLPack tensor conversion happens at C++ level — zero Python overhead - per call. This matches the quack ``_compile_rmsnorm_fwd`` API. - """ - key = ( - dtype, out_dtype, bias_dtype, weight_dtype, - residual_dtype, residual_out_dtype, N, - store_rstd, has_residual, has_bias, - ) - cached = _COMPILE_FWD_CACHE.get(key) - if cached is not None: - return cached - - # Resolve schedule once. - direct_gmem = _direct_gmem_from_policy( - default=bool(dtype.width == 16 and N in {128, 4096, 6144, 7168, 8192}) - ) - use_async = not direct_gmem - copy_bits = _copy_bits_from_policy(default=128, can_use_256=False) - stage = 1 - - op = RMSNormSM100( - N, - dtype, - stage=stage, - copy_bits=int(copy_bits), - use_async=bool(use_async), - direct_gmem=bool(direct_gmem), - ) - - # Build symbolic tensors for compilation (same as quack's fake_tensor). 
-    batch_sym = cute.sym_int()
-    all_dtypes = [dtype, out_dtype, residual_dtype, weight_dtype,
-                  bias_dtype, residual_out_dtype]
-    import math as _math
-    div = _math.gcd(N, *(128 // dt.width for dt in all_dtypes if dt is not None))
-
-    x_fake = _make_fake_tensor(dtype, (batch_sym, N), div)
-    out_fake = _make_fake_tensor(out_dtype or dtype, (batch_sym, N), div)
-    res_fake = _make_fake_tensor(residual_dtype, (batch_sym, N), div) if has_residual else None
-    res_out_fake = _make_fake_tensor(residual_out_dtype, (batch_sym, N), div) if residual_out_dtype else None
-    w_fake = _make_fake_tensor(weight_dtype, (N,), div) if weight_dtype else None
-    b_fake = _make_fake_tensor(bias_dtype, (N,), div) if bias_dtype else None
-    rstd_fake = _make_fake_tensor(cutlass.Float32, (batch_sym,)) if store_rstd else None
-
-    # Compile with TVM FFI — the returned callable accepts torch tensors
-    # directly via DLPack at C++ level.
-    compiled = cute.compile(
-        op,
-        x_fake,
-        w_fake,
-        b_fake,
-        res_fake,
-        out_fake,
-        res_out_fake,
-        rstd_fake,
-        cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
-        Float32(0),  # eps placeholder
-        options="--enable-tvm-ffi",
-    )
-
-    _COMPILE_FWD_CACHE[key] = compiled
-    return compiled
-
-
 def _fused_add_rmsnorm_forward_ptr_inplace(
     x: Tensor,
     residual: Tensor,
diff --git a/oink/tests/test_aten_override.py b/oink/tests/test_aten_override.py
index 6ecb635..76c20c5 100644
--- a/oink/tests/test_aten_override.py
+++ b/oink/tests/test_aten_override.py
@@ -12,14 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Tests for Oink's aten operator overrides.
+"""Tests for Oink's operator overrides.
 
-Verifies that ``register_all_kernels`` / ``override_all_aten_kernels``
-properly patches PyTorch's aten ops and that the overridden kernels produce
-numerically correct results compared to PyTorch's native CUDA kernels.
+Verifies that ``register_all_kernels`` / ``override_all_kernels`` properly
+patch the aten RMSNorm ops (forward and backward) with Oink's kernels, and
+that the overridden kernels produce numerically correct results.
 
-The correctness tests capture the original CUDA kernel *before* the override
-is applied, then compare the override's output against it.
+Reference values are computed by calling the original aten CUDA kernel
+captured via ``torch.library.get_kernel`` before the override is applied,
+invoked with ``call_boxed(DispatchKeySet, ...)``.
 """
 
 from __future__ import annotations
@@ -29,65 +30,38 @@
 import pytest
 import torch
 
-# ---------------------------------------------------------------------------
-# Skip helpers
-# ---------------------------------------------------------------------------
-
 TEST_CUDA = torch.cuda.is_available()
 _SM = 0
 if TEST_CUDA:
     _major, _minor = torch.cuda.get_device_capability(0)
     _SM = 10 * _major + _minor
-
 SM100_OR_LATER = _SM >= 100
 
 requires_cuda = pytest.mark.skipif(not TEST_CUDA, reason="CUDA not available")
 requires_sm100 = pytest.mark.skipif(not SM100_OR_LATER, reason="requires SM100+")
 
-# ---------------------------------------------------------------------------
-# Capture original CUDA kernels *before* any override is applied.
-# --------------------------------------------------------------------------- -_orig_fused_rms_norm = None -_orig_fused_rms_norm_bwd = None -_orig_native_layer_norm = None -_orig_native_layer_norm_bwd = None -_orig_softmax = None -_orig_softmax_bwd = None +_CUDA_KS = None +_orig_kernels: dict[str, object] = {} if TEST_CUDA: try: - _orig_fused_rms_norm = torch.library.get_kernel( - "aten::_fused_rms_norm", "CUDA" - ) - _orig_fused_rms_norm_bwd = torch.library.get_kernel( - "aten::_fused_rms_norm_backward", "CUDA" - ) - _orig_native_layer_norm = torch.library.get_kernel( - "aten::native_layer_norm", "CUDA" - ) - _orig_native_layer_norm_bwd = torch.library.get_kernel( - "aten::native_layer_norm_backward", "CUDA" - ) - _orig_softmax = torch.library.get_kernel("aten::_softmax", "CUDA") - _orig_softmax_bwd = torch.library.get_kernel( - "aten::_softmax_backward_data", "CUDA" - ) + _CUDA_KS = torch.DispatchKeySet(torch.DispatchKey.CUDA) + for _op in [ + "_fused_rms_norm", + "_fused_rms_norm_backward", + ]: + _orig_kernels[_op] = torch.library.get_kernel(f"aten::{_op}", "CUDA") except Exception: pass -# --------------------------------------------------------------------------- -# Apply the override (module-level, happens once). -# --------------------------------------------------------------------------- - _OVERRIDE_APPLIED = False - if TEST_CUDA and SM100_OR_LATER: try: - from kernelagent_oink.aten_override import override_all_aten_kernels + from kernelagent_oink.aten_override import override_all_kernels - override_all_aten_kernels() + override_all_kernels() _OVERRIDE_APPLIED = True except Exception: pass @@ -96,9 +70,6 @@ not _OVERRIDE_APPLIED, reason="override not applied" ) -# --------------------------------------------------------------------------- -# Constants -# --------------------------------------------------------------------------- SHAPES = [(8, 128), (4, 8, 32), (2, 16, 512), (4, 32, 1024)] DTYPES = [torch.float16, torch.bfloat16, torch.float32] @@ -106,42 +77,20 @@ def _atol_for(dtype): - if dtype in (torch.float16, torch.bfloat16): - return 1e-1 - return 1e-5 - - -# ========================================================================= -# Registration tests -# ========================================================================= + if dtype == torch.bfloat16: + return 1e-1 # bf16 has 8-bit mantissa, larger rounding error + if dtype == torch.float16: + return 1e-2 # fp16 has 11-bit mantissa + return 1e-4 # fp32 @requires_cuda @requires_sm100 def test_override_sets_library(): - """The aten Library object should be non-None after override.""" - from kernelagent_oink.aten_override import _ATEN_LIB - - assert _ATEN_LIB is not None, "override_all_aten_kernels did not create Library" + """The Library object should be non-None after override.""" + from kernelagent_oink.aten_override import _OVERRIDE_LIB - -@requires_cuda -@requires_sm100 -def test_all_fallbacks_captured(): - """All 6 fallback kernels should have been captured.""" - from kernelagent_oink.aten_override import _fallbacks - - expected_ops = [ - "_fused_rms_norm", - "_fused_rms_norm_backward", - "native_layer_norm", - "native_layer_norm_backward", - "_softmax", - "_softmax_backward_data", - ] - for op in expected_ops: - assert op in _fallbacks, f"fallback not captured for {op}" - assert _fallbacks[op] is not None, f"fallback is None for {op}" + assert _OVERRIDE_LIB is not None, "override_all_kernels did not create Library" @requires_cuda @@ -155,86 +104,32 @@ def test_custom_ops_registered(): assert 
hasattr(torch.ops.oink, "rmsnorm"), "torch.ops.oink.rmsnorm missing" -# ========================================================================= -# Availability / stride-guard tests (no GPU required for some) -# ========================================================================= - - def test_oink_availability_checks(monkeypatch: pytest.MonkeyPatch): - """Probe is_oink_available_for_device with mocked CUDA.""" - from kernelagent_oink.aten_override import _is_supported + """Probe _is_supported with mocked CUDA.""" + from kernelagent_oink.aten_override import _get_device_major, _is_supported - # Mock a CUDA tensor with SM90 (below threshold). fake_tensor = types.SimpleNamespace( is_cuda=True, dtype=torch.float16, device=torch.device("cuda:0") ) - # SM90 → not supported. + # SM90 (Hopper) → not supported (SM100+ only). monkeypatch.setattr(torch.cuda, "is_available", lambda: True) monkeypatch.setattr(torch.cuda, "get_device_capability", lambda d: (9, 0)) - # Clear the cached SM value. - from kernelagent_oink.aten_override import _get_device_sm - - _get_device_sm.cache_clear() + _get_device_major.cache_clear() assert _is_supported(fake_tensor) is False # SM100 → supported. monkeypatch.setattr(torch.cuda, "get_device_capability", lambda d: (10, 0)) - _get_device_sm.cache_clear() + _get_device_major.cache_clear() assert _is_supported(fake_tensor) is True - # float64 → not supported even on SM100. + # float64 → not supported fake_f64 = types.SimpleNamespace( is_cuda=True, dtype=torch.float64, device=torch.device("cuda:0") ) assert _is_supported(fake_f64) is False - _get_device_sm.cache_clear() - - -def test_can_view_as_2d_stride_guard(): - """Verify _can_view_as_2d correctly identifies non-viewable layouts.""" - from kernelagent_oink.aten_override import _can_view_as_2d - - x = torch.zeros((2, 3, 4)) - assert _can_view_as_2d(x) is True - - # Size-1 dims should be ignored by the viewability check. - base = torch.zeros((2, 10, 4)) - x_singleton = base[:, :1, :] - assert _can_view_as_2d(x_singleton) is True - - # Middle-dimension stride break: view(-1, hidden) should be invalid. - x2 = x[:, ::2, :] - with pytest.raises(RuntimeError): - x2.view(-1, x2.shape[-1]) - assert _can_view_as_2d(x2) is False - - -def test_is_oink_stride_compatible_2d(): - """Verify vectorization alignment check.""" - from kernelagent_oink.aten_override import _is_oink_stride_compatible_2d - - # Standard contiguous tensor (stride(0)==N, stride(1)==1) → compatible. - x = torch.zeros(4, 128, dtype=torch.float16) - assert _is_oink_stride_compatible_2d(x) is True - - # Padded row: stride(0) % 16 == 0 → compatible. - base = torch.zeros(4, 256, dtype=torch.float16) - x_padded = base[:, :128] # stride(0)=256, stride(1)=1 - assert x_padded.stride(0) == 256 - assert _is_oink_stride_compatible_2d(x_padded) is True - - # 1D tensor → not compatible. - assert _is_oink_stride_compatible_2d(torch.zeros(128)) is False - - # Wrong dtype → not compatible. - assert _is_oink_stride_compatible_2d(torch.zeros(4, 128, dtype=torch.float64)) is False - - -# ========================================================================= -# Correctness tests — RMSNorm -# ========================================================================= + _get_device_major.cache_clear() @requires_cuda @@ -248,177 +143,91 @@ def test_rmsnorm_fwd(dtype): x = torch.randn(*shape, dtype=dtype, device="cuda") w = torch.randn(*normalized_shape, dtype=dtype, device="cuda") + # Oink override. 
y, rstd = torch.ops.aten._fused_rms_norm(x, normalized_shape, w, EPS) - y_ref, rstd_ref = _orig_fused_rms_norm(x, normalized_shape, w, EPS) - torch.testing.assert_close( - y, y_ref, atol=atol, rtol=0, msg=f"fwd y shape={shape} dtype={dtype}" + # Native aten reference. + y_ref, rstd_ref = _orig_kernels["_fused_rms_norm"].call_boxed( + _CUDA_KS, x, normalized_shape, w, EPS ) - torch.testing.assert_close( - rstd, rstd_ref, atol=1e-5, rtol=0, - msg=f"fwd rstd shape={shape} dtype={dtype}", - ) - - -@requires_cuda -@requires_sm100 -@requires_override -@pytest.mark.parametrize("dtype", DTYPES) -def test_rmsnorm_bwd(dtype): - atol = 3e-1 if dtype == torch.bfloat16 else _atol_for(dtype) - for shape in SHAPES: - normalized_shape = [shape[-1]] - x = torch.randn(*shape, dtype=dtype, device="cuda") - w = torch.randn(*normalized_shape, dtype=dtype, device="cuda") - grad_out = torch.randn(*shape, dtype=dtype, device="cuda") - - x1 = x.detach().requires_grad_(True) - w1 = w.detach().requires_grad_(True) - y1, _ = torch.ops.aten._fused_rms_norm(x1, normalized_shape, w1, EPS) - y1.backward(grad_out) - - x2 = x.detach().requires_grad_(True) - w2 = w.detach().requires_grad_(True) - y2, _ = _orig_fused_rms_norm(x2, normalized_shape, w2, EPS) - y2.backward(grad_out) torch.testing.assert_close( - x1.grad, x2.grad, atol=atol, rtol=0, - msg=f"bwd x_grad shape={shape} dtype={dtype}", + y, y_ref, atol=atol, rtol=0, msg=f"fwd y shape={shape} dtype={dtype}" ) torch.testing.assert_close( - w1.grad, w2.grad, atol=atol, rtol=0, - msg=f"bwd w_grad shape={shape} dtype={dtype}", + rstd, rstd_ref, atol=_atol_for(rstd.dtype), rtol=0, + msg=f"fwd rstd shape={shape} rstd_dtype={rstd.dtype}", ) -# ========================================================================= -# Correctness tests — LayerNorm -# ========================================================================= - - @requires_cuda @requires_sm100 @requires_override @pytest.mark.parametrize("dtype", DTYPES) -def test_layernorm_fwd(dtype): +def test_rmsnorm_bwd(dtype): atol = _atol_for(dtype) for shape in SHAPES: normalized_shape = [shape[-1]] x = torch.randn(*shape, dtype=dtype, device="cuda") w = torch.randn(*normalized_shape, dtype=dtype, device="cuda") - b = torch.randn(*normalized_shape, dtype=dtype, device="cuda") + grad_out = torch.randn(*shape, dtype=dtype, device="cuda") - out, mean, rstd = torch.ops.aten.native_layer_norm( - x, normalized_shape, w, b, EPS - ) - out_ref, mean_ref, rstd_ref = _orig_native_layer_norm( - x, normalized_shape, w, b, EPS + # Get rstd from native aten forward (needed by backward). + _, rstd_ref = _orig_kernels["_fused_rms_norm"].call_boxed( + _CUDA_KS, x, normalized_shape, w, EPS ) - torch.testing.assert_close( - out, out_ref, atol=atol, rtol=0, msg=f"fwd shape={shape} dtype={dtype}" - ) - torch.testing.assert_close( - mean, mean_ref, atol=1e-5, rtol=0, - msg=f"fwd mean shape={shape} dtype={dtype}", - ) - torch.testing.assert_close( - rstd, rstd_ref, atol=1e-5, rtol=0, - msg=f"fwd rstd shape={shape} dtype={dtype}", + # Oink override backward. 
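+        # Both calls consume the same native rstd, so only the backward
+        # kernels differ between the two paths.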
+ dx, dw = torch.ops.aten._fused_rms_norm_backward( + grad_out, x, normalized_shape, rstd_ref, w, [True, True] ) - -@requires_cuda -@requires_sm100 -@requires_override -@pytest.mark.parametrize("dtype", DTYPES) -def test_layernorm_bwd(dtype): - atol = 3e-1 if dtype == torch.bfloat16 else _atol_for(dtype) - for shape in SHAPES: - normalized_shape = [shape[-1]] - x = torch.randn(*shape, dtype=dtype, device="cuda") - w = torch.randn(*normalized_shape, dtype=dtype, device="cuda") - b = torch.randn(*normalized_shape, dtype=dtype, device="cuda") - grad_out = torch.randn(*shape, dtype=dtype, device="cuda") - - x1 = x.detach().requires_grad_(True) - w1 = w.detach().requires_grad_(True) - b1 = b.detach().requires_grad_(True) - out1, _, _ = torch.ops.aten.native_layer_norm( - x1, normalized_shape, w1, b1, EPS + # Native aten reference backward. + dx_ref, dw_ref = _orig_kernels["_fused_rms_norm_backward"].call_boxed( + _CUDA_KS, grad_out, x, normalized_shape, rstd_ref, w, [True, True] ) - out1.backward(grad_out) - - x2 = x.detach().requires_grad_(True) - w2 = w.detach().requires_grad_(True) - b2 = b.detach().requires_grad_(True) - out2, _, _ = _orig_native_layer_norm(x2, normalized_shape, w2, b2, EPS) - out2.backward(grad_out) torch.testing.assert_close( - x1.grad, x2.grad, atol=atol, rtol=0, - msg=f"bwd x_grad shape={shape} dtype={dtype}", + dx, dx_ref, atol=atol, rtol=0, + msg=f"bwd dx shape={shape} dtype={dtype}", ) torch.testing.assert_close( - w1.grad, w2.grad, atol=atol, rtol=0, - msg=f"bwd w_grad shape={shape} dtype={dtype}", + dw, dw_ref, atol=atol, rtol=0, + msg=f"bwd dw shape={shape} dtype={dtype}", ) - torch.testing.assert_close( - b1.grad, b2.grad, atol=atol, rtol=0, - msg=f"bwd b_grad shape={shape} dtype={dtype}", - ) - - -# ========================================================================= -# Correctness tests — Softmax -# ========================================================================= @requires_cuda @requires_sm100 @requires_override -@pytest.mark.parametrize("dtype", DTYPES) -def test_softmax_fwd(dtype): - atol = _atol_for(dtype) - for shape in SHAPES: - x = torch.randn(*shape, dtype=dtype, device="cuda") - - y = torch.ops.aten._softmax(x, -1, False) - y_ref = _orig_softmax(x, -1, False) - - torch.testing.assert_close( - y, y_ref, atol=atol, rtol=0, msg=f"fwd shape={shape} dtype={dtype}" - ) - - -@requires_cuda -@requires_sm100 -@requires_override -@pytest.mark.parametrize("dtype", DTYPES) -def test_softmax_bwd(dtype): - atol = 3e-1 if dtype == torch.bfloat16 else _atol_for(dtype) - for shape in SHAPES: - x = torch.randn(*shape, dtype=dtype, device="cuda") - grad_out = torch.randn(*shape, dtype=dtype, device="cuda") - - x1 = x.detach().requires_grad_(True) - y1 = torch.softmax(x1, dim=-1) - y1.backward(grad_out) +@pytest.mark.parametrize( + "mask", [[True, True], [True, False], [False, True], [False, False]] +) +def test_backward_output_mask(mask): + """Backward output_mask behavior should match native aten exactly.""" + x = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda") + w = torch.randn(128, dtype=torch.bfloat16, device="cuda") + grad = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda") - x2 = x.detach().requires_grad_(True) - y2 = _orig_softmax(x2, -1, False) - dx_ref = _orig_softmax_bwd(grad_out, y2, -1, dtype) + _, rstd = torch.ops.aten._fused_rms_norm(x, [128], w, EPS) - torch.testing.assert_close( - x1.grad, dx_ref, atol=atol, rtol=0, - msg=f"bwd shape={shape} dtype={dtype}", - ) + # Oink override. 
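+    # Only the None-ness of each output is compared here; numerical
+    # agreement is covered by test_rmsnorm_bwd.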
+ dx, dw = torch.ops.aten._fused_rms_norm_backward( + grad, x, [128], rstd, w, mask + ) + # Native aten reference. + _, rstd_ref = _orig_kernels["_fused_rms_norm"].call_boxed(_CUDA_KS, x, [128], w, EPS) + dx_ref, dw_ref = _orig_kernels["_fused_rms_norm_backward"].call_boxed( + _CUDA_KS, grad, x, [128], rstd_ref, w, mask + ) -# ========================================================================= -# Fallback tests -# ========================================================================= + assert (dx is None) == (dx_ref is None), ( + f"dx None mismatch: oink={dx is None}, aten={dx_ref is None} for mask={mask}" + ) + assert (dw is None) == (dw_ref is None), ( + f"dw None mismatch: oink={dw is None}, aten={dw_ref is None} for mask={mask}" + ) @requires_cuda @@ -435,33 +244,16 @@ def test_float64_rmsnorm_falls_back(): @requires_cuda @requires_sm100 -@requires_override -def test_float64_layernorm_falls_back(): - x = torch.randn(4, 32, dtype=torch.float64, device="cuda") - w = torch.randn(32, dtype=torch.float64, device="cuda") - b = torch.randn(32, dtype=torch.float64, device="cuda") - out, mean, rstd = torch.ops.aten.native_layer_norm(x, [32], w, b, EPS) - assert out.shape == x.shape - assert out.dtype == torch.float64 +def test_restore_then_reregister(): + """restore + re-register should work in the same process.""" + from kernelagent_oink import unregister_all_kernels + from kernelagent_oink.aten_override import override_all_kernels + unregister_all_kernels() -@requires_cuda -@requires_sm100 -@requires_override -def test_float64_softmax_falls_back(): - x = torch.randn(4, 32, dtype=torch.float64, device="cuda") - y = torch.ops.aten._softmax(x, -1, False) - assert y.shape == x.shape - assert y.dtype == torch.float64 + # After unregister, re-register should succeed. + override_all_kernels() + from kernelagent_oink.aten_override import _OVERRIDE_LIB -@requires_cuda -@requires_sm100 -@requires_override -def test_non_last_dim_softmax_falls_back(): - """Softmax on dim=0 should fall back (oink only handles last dim).""" - x = torch.randn(4, 32, dtype=torch.float16, device="cuda") - y = torch.ops.aten._softmax(x, 0, False) - assert y.shape == x.shape - y_ref = _orig_softmax(x, 0, False) - torch.testing.assert_close(y, y_ref, atol=1e-3, rtol=0) + assert _OVERRIDE_LIB is not None