diff --git a/oink/README.md b/oink/README.md index bcb1032..30d0935 100644 --- a/oink/README.md +++ b/oink/README.md @@ -24,7 +24,7 @@ Recommended env vars: ```bash export PYTORCH_ALLOC_CONF=expandable_segments:True -export CUTE_DSL_ARCH=sm_103a # GB300 / SM103 +export CUTE_DSL_ARCH=sm_103 # GB300 / SM103 on the current CuTeDSL host # export CUTE_DSL_ARCH=sm_100a # GB200/B200 / SM100 ``` @@ -88,7 +88,7 @@ Reported numbers are correctness-gated against PyTorch references before timing. Current GB300 / SM103 setup: -- NVIDIA GB300, capability `(10, 3)`, `CUTE_DSL_ARCH=sm_103a` +- NVIDIA GB300, capability `(10, 3)`, `CUTE_DSL_ARCH=sm_103` - `torch==2.11.0+cu130`, CUDA `13.0` - `nvidia-cutlass-dsl==4.4.2`, `cuda-python==13.2.0` - measured BF16 STREAM-like roof: **7.140 TB/s** @@ -113,6 +113,41 @@ Historical plots remain under `benchmarks/media/`: - `gb300_bf16_qk_norm_oink_vs_cutedsl_roofline.svg`: historical GB300 Q/K-norm harness, separate from the Quack-suite table above. +### GB300 (SM103) LayerNorm backward results + +Oink's LayerNorm backward path is self-contained in this repo. The OSS +benchmark reports Oink against ATen's native LayerNorm backward reference. + +Measured on **GB300 (SM103)** in the `cute` Conda environment with torch +`2.11.0+cu130`, CUDA `13.0`, and `CUTE_DSL_ARCH=sm_103`, using CUDA graph warm +replay (`--cuda-graph`), bf16 activations/gradients, same-dtype LayerNorm +weights, and no bias. Correctness was checked before timing against a chunked +fp32 PyTorch formula for `dx` / `dweight`; the timed `ref` column uses +`torch.ops.aten.native_layer_norm_backward.default` with the same precomputed +`mean` and `rstd`. + +The OSS Quack package installed in this environment exposes LayerNorm forward but +not a `quack.rmsnorm.layernorm_bwd` API, so the benchmark reports Quack as +unavailable and omits Quack timing columns. If a Quack build with +`layernorm_bwd` is installed, the same command will add `quack_ms` and +`Oink/Quack` columns. + +The throughput columns use a logical useful-IO model for no-bias LayerNorm +backward: read `x`, read `dout`, write `dx`, read/write `weight`/`dweight`, and +read fp32 `mean` + `rstd`. This excludes implementation-specific scratch traffic, +so the values are a useful-bandwidth roofline view rather than physical HBM bytes. +Full DSv3/DSv4 tables are in [`benchmarks/README.md`](benchmarks/README.md). + + +Reproduce with: + +```bash +env PYTHONNOUSERSITE=1 CUTE_DSL_ARCH=sm_103 PYTORCH_ALLOC_CONF=expandable_segments:True \ + conda run -n cute python -u oink/benchmarks/benchmark/benchmark_layernorm_bwd_sm100.py \ + --dtype bf16 --weight-dtype same --dsv4 --iters 80 --warmup-ms 10 --cuda-graph \ + --json /tmp/oink_layernorm_bwd_sm103_dsv4_cuda_graph_seq.json +``` + ## Links | What | Link | diff --git a/oink/benchmarks/README.md b/oink/benchmarks/README.md index 9685000..9ad35f7 100644 --- a/oink/benchmarks/README.md +++ b/oink/benchmarks/README.md @@ -18,8 +18,8 @@ Recommended env vars: ```bash export PYTORCH_ALLOC_CONF=expandable_segments:True -# GB300 / SM103: -export CUTE_DSL_ARCH=sm_103a +# GB300 / SM103 on the current CuTeDSL host: +export CUTE_DSL_ARCH=sm_103 # GB200/B200 / SM100 historical runs: # export CUTE_DSL_ARCH=sm_100a ``` @@ -42,6 +42,7 @@ conda run -n cute python -m pip install 'git+https://github.com/Dao-AILab/quack. with `hidden = 4096` so `M = batch * seq`, `N = 4096`. 
- **DeepSeek-V3-like (DSv3)** - RMSNorm / LayerNorm / Softmax: `M ∈ {4096, 16384, 65536}`, `N ∈ {6144, 7168, 8192}` + - LayerNorm backward's `--dsv3` suite uses `N ∈ {6144, 8192}`; use `--dsv4` for the `N = 7168` hidden-state sweep. - Cross-entropy: `M ∈ {4096, 16384, 65536}`, `N ∈ {3072, 6144, 8192, 12288}` - **DeepSeek-V4-Flash norm shapes (DSv4)** from `deepseek-ai/DeepSeek-V4-Flash/inference/model.py` - hidden-state RMSNorm / LayerNorm: `M ∈ {4096, 16384, 65536}`, `N = 7168` @@ -73,7 +74,7 @@ Current measured GB300 BF16 STREAM-like roof used in the README: Regenerate on the current machine: ```bash -conda run -n cute bash -lc 'PYTHONNOUSERSITE=1 CUTE_DSL_ARCH=sm_103a \ +conda run -n cute bash -lc 'PYTHONNOUSERSITE=1 CUTE_DSL_ARCH=sm_103 \ python benchmarks/benchmark/benchmark_hbm_roofline_sm100.py --dtype bf16 --op both --gb 1 \ --json /tmp/oink_sm103_hbm_roofline_bf16_current.json' ``` @@ -94,11 +95,11 @@ Run the full Quack-suite + DSv3 set (Oink vs Quack) and write all JSON artifacts to a timestamped directory: ```bash -conda run -n cute bash -lc 'PYTHONNOUSERSITE=1 CUTE_DSL_ARCH=sm_103a \ +conda run -n cute bash -lc 'PYTHONNOUSERSITE=1 CUTE_DSL_ARCH=sm_103 \ python benchmarks/readme/run_sm100_suite.py --dtype bf16' # Include DeepSeek-V4-Flash norm workloads: -conda run -n cute bash -lc 'PYTHONNOUSERSITE=1 CUTE_DSL_ARCH=sm_103a \ +conda run -n cute bash -lc 'PYTHONNOUSERSITE=1 CUTE_DSL_ARCH=sm_103 \ python benchmarks/readme/run_sm100_suite.py --dtype bf16 --include-dsv4 \ --out-dir /tmp/oink_sm103_suite_bf16_current' ``` @@ -162,14 +163,14 @@ outputs. ```bash # DeepSeek-V3 hidden-size sweep -PYTHONNOUSERSITE=1 CUTE_DSL_ARCH=sm_103a \ +PYTHONNOUSERSITE=1 CUTE_DSL_ARCH=sm_103 \ python benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py \ --dtype bf16 --dsv3 --iters 80 --warmup-ms 15 \ --quack-baseline kernel_inplace \ --json /tmp/oink_sm103_fused_add_rmsnorm_dsv3_bf16.json # DeepSeek-V4-Flash hidden-state sweep (N=7168) -PYTHONNOUSERSITE=1 CUTE_DSL_ARCH=sm_103a \ +PYTHONNOUSERSITE=1 CUTE_DSL_ARCH=sm_103 \ python benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py \ --dtype bf16 --dsv4 --iters 80 --warmup-ms 15 \ --quack-baseline kernel_inplace \ @@ -245,6 +246,48 @@ python benchmarks/benchmark/benchmark_layernorm_sm100.py --dtype bf16 --dsv3 --i --json /tmp/oink_layernorm_fwd_dsv3.json ``` +### LayerNorm backward + +This compares Oink against ATen's native LayerNorm backward reference and, +when the installed OSS Quack package exposes `quack.rmsnorm.layernorm_bwd`, Quack +LayerNorm backward. The benchmark validates each available backend against a +chunked fp32 PyTorch formula before timing. Current table numbers use CUDA graph +warm replay (`--cuda-graph`). The local Quack package used for these runs exposes +LayerNorm forward but not `layernorm_bwd`, so Quack timing columns are omitted. 
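+
+That chunked fp32 formula is the standard LayerNorm backward identity. A minimal
+sketch of it, together with the useful-IO byte model behind the `Oink TB/s`
+column, is shown below; the helper names (`layernorm_bwd_reference`,
+`useful_io_bytes`) are illustrative only and the real script chunks rows instead
+of materializing whole fp32 tensors:
+
+```python
+import torch
+
+def layernorm_bwd_reference(dout, x, weight, mean, rstd):
+    # fp32 math throughout; mean/rstd are the per-row stats shared by all backends.
+    x32, dout32, w32 = x.float(), dout.float(), weight.float()
+    x_hat = (x32 - mean[:, None]) * rstd[:, None]
+    wdy = dout32 * w32
+    dx = (wdy - wdy.mean(-1, keepdim=True)
+          - x_hat * (x_hat * wdy).mean(-1, keepdim=True)) * rstd[:, None]
+    dweight = (dout32 * x_hat).sum(0)  # dbias = dout32.sum(0) when bias is used
+    return dx.to(x.dtype), dweight.to(weight.dtype)
+
+def useful_io_bytes(M, N, elem=2, w_elem=2):
+    # read x + dout, write dx; read weight, write dweight; read fp32 mean + rstd.
+    return 3 * M * N * elem + 2 * N * w_elem + 2 * M * 4
+
+# Example: M=65536, N=8192 in bf16 is ~3.22 GB of useful IO, so the 0.7372 ms
+# row below corresponds to ~4.37 TB/s.
+```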
+ +DSv3 CUDA-graph replay results (`N ∈ {6144,8192}`): + +| M | N | Oink ms | Oink TB/s | ATen ref ms | Oink/ref | +|---:|---:|---:|---:|---:|---:| +| 4096 | 6144 | 0.0548 | 2.7574 | 0.0777 | 1.4190x | +| 4096 | 8192 | 0.0611 | 3.2951 | 0.0970 | 1.5873x | +| 16384 | 6144 | 0.1840 | 3.2833 | 0.2794 | 1.5183x | +| 16384 | 8192 | 0.2093 | 3.8480 | 0.3387 | 1.6183x | +| 65536 | 6144 | 0.6896 | 3.5043 | 1.0652 | 1.5447x | +| 65536 | 8192 | 0.7372 | 4.3705 | 1.3138 | 1.7823x | + +DSv4 hidden LayerNorm CUDA-graph replay results (`N = 7168`): + +| M | N | Oink ms | Oink TB/s | ATen ref ms | Oink/ref | +|---:|---:|---:|---:|---:|---:| +| 4096 | 7168 | 0.0591 | 2.9800 | 0.0858 | 1.4503x | +| 16384 | 7168 | 0.1990 | 3.5425 | 0.3077 | 1.5467x | +| 65536 | 7168 | 0.7467 | 3.7753 | 1.1711 | 1.5684x | + +```bash +# DeepSeek-V4-Flash hidden LayerNorm shape sweep (N=7168) +env PYTHONNOUSERSITE=1 CUTE_DSL_ARCH=sm_103 PYTORCH_ALLOC_CONF=expandable_segments:True \ + conda run -n cute python -u benchmarks/benchmark/benchmark_layernorm_bwd_sm100.py \ + --dtype bf16 --weight-dtype same --dsv4 --iters 80 --warmup-ms 10 --cuda-graph \ + --json /tmp/oink_layernorm_bwd_sm103_dsv4_cuda_graph_seq.json + +# DeepSeek-V3 shape sweep (N in {6144,8192}) +env PYTHONNOUSERSITE=1 CUTE_DSL_ARCH=sm_103 PYTORCH_ALLOC_CONF=expandable_segments:True \ + conda run -n cute python -u benchmarks/benchmark/benchmark_layernorm_bwd_sm100.py \ + --dtype bf16 --weight-dtype same --dsv3 --iters 80 --warmup-ms 10 --cuda-graph \ + --json /tmp/oink_layernorm_bwd_sm103_dsv3_cuda_graph_seq.json +``` + ## Notes - These scripts intentionally avoid importing any external Oink checkout so the diff --git a/oink/benchmarks/benchmark/bench_utils.py b/oink/benchmarks/benchmark/bench_utils.py index 6ba7a3f..bea7962 100644 --- a/oink/benchmarks/benchmark/bench_utils.py +++ b/oink/benchmarks/benchmark/bench_utils.py @@ -75,7 +75,9 @@ def ensure_blackwell_arch_env(device: Optional[torch.device] = None) -> str: Benchmarks often run outside the Oink/vLLM plugin path, so they don't benefit from the plugin's device-capability-based `CUTE_DSL_ARCH` setup. - On GB300 we want `sm_103a` instead of the older hard-coded `sm_100a`. + On this GB300/CuTeDSL 4.4.2 host, LayerNorm backward compiles reliably + with `sm_103`; callers may still pin an `a` arch explicitly if their local + CuTeDSL build requires it. 
""" pinned = os.environ.get("CUTE_DSL_ARCH") if pinned: @@ -86,7 +88,9 @@ def ensure_blackwell_arch_env(device: Optional[torch.device] = None) -> str: if device is None: device = torch.device("cuda") major, minor = torch.cuda.get_device_capability(device) - if int(major) == 10: + if int(major) == 10 and int(minor) == 3: + arch = "sm_103" + elif int(major) == 10: arch = f"sm_{int(major)}{int(minor)}a" os.environ["CUTE_DSL_ARCH"] = arch return arch @@ -115,6 +119,13 @@ def do_bench_triton( return float(triton_do_bench(fn, warmup=warmup_ms, rep=rep_ms, return_mode="mean")) +def do_bench_cuda_graph(fn: Callable[[], Any], *, rep_ms: int = 100) -> float: + """CUDA-graph replay timing via Triton's cudagraph benchmark helper.""" + from triton.testing import do_bench_cudagraph + + return float(do_bench_cudagraph(fn, rep=rep_ms, return_mode="mean")) + + def parse_dtype(s: str) -> torch.dtype: s = s.lower() if s == "bf16": diff --git a/oink/benchmarks/benchmark/benchmark_layernorm_bwd_sm100.py b/oink/benchmarks/benchmark/benchmark_layernorm_bwd_sm100.py new file mode 100644 index 0000000..756573c --- /dev/null +++ b/oink/benchmarks/benchmark/benchmark_layernorm_bwd_sm100.py @@ -0,0 +1,625 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import argparse +import importlib +import os +import sys +from dataclasses import dataclass +from types import ModuleType +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +from bench_utils import ( + ErrorStatsAccumulator, + collect_device_meta, + detect_hbm_peak_gbps, + do_bench_cuda_graph, + do_bench_triton, + dsv4_hidden_norm_configs, + ensure_blackwell_arch_env, + ensure_oink_src_on_path, + error_stats_to_row, + iter_row_blocks, + parse_configs, + parse_dtype, + quack_suite_configs, + write_csv, + write_json, +) + +# Reduce fragmentation pressure on busy GPUs. +os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True") +ensure_blackwell_arch_env() +ensure_oink_src_on_path() + +_oink_ln: ModuleType | None = None +quack_layernorm_bwd: Callable[..., Any] | None = None +_QUACK_LAYERNORM_BWD_STATUS = "uninitialized" + + +def _load_optional_quack_layernorm_bwd() -> None: + global quack_layernorm_bwd, _QUACK_LAYERNORM_BWD_STATUS + try: + module = importlib.import_module("quack.rmsnorm") + quack_layernorm_bwd = getattr(module, "layernorm_bwd") + _QUACK_LAYERNORM_BWD_STATUS = "available: quack.rmsnorm.layernorm_bwd" + except Exception as e: + quack_layernorm_bwd = None + _QUACK_LAYERNORM_BWD_STATUS = f"unavailable: {type(e).__name__}: {e}" + + +_load_optional_quack_layernorm_bwd() + + +def _get_oink_layernorm() -> ModuleType: + global _oink_ln + if _oink_ln is None: + _oink_ln = importlib.import_module("kernelagent_oink.blackwell.layernorm") + return _oink_ln + + +_VERIFY_TOL_DX = { + # Match the existing Oink backward benchmark tolerance style. 
+ torch.float32: dict(atol=1e-4, rtol=1e-3), + torch.float16: dict(atol=1e-2, rtol=1e-3), + torch.bfloat16: dict(atol=1e-1, rtol=1e-2), +} + + +@dataclass(frozen=True) +class BenchResult: + ms: float + gbps: float + + @property + def tbps(self) -> float: + return self.gbps / 1000.0 + + +BackendFn = Callable[ + [], Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]] +] + + +def dsv3_configs() -> List[Tuple[int, int]]: + # CuteDSL kernel workflow default for RMSNorm/LayerNorm when shapes are unspecified. + Ms = [4096, 16384, 65536] + Ns = [6144, 8192] + return [(m, n) for m in Ms for n in Ns] + + +def _call_quack_layernorm_bwd( + dout: torch.Tensor, + x: torch.Tensor, + weight: torch.Tensor, + mean: torch.Tensor, + rstd: torch.Tensor, + bias: Optional[torch.Tensor], +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + if quack_layernorm_bwd is None: + raise RuntimeError(_QUACK_LAYERNORM_BWD_STATUS) + if bias is not None: + raise RuntimeError("Quack LayerNorm backward bias path is not benchmarked") + + # Quack has changed public naming across releases. Prefer the common + # cute-kernels-style positional API, then try keyword spellings used by + # adjacent norm wrappers. + call_errors: list[str] = [] + for args, kwargs in ( + ((dout, x, weight, mean, rstd), {}), + ((), {"dout": dout, "x": x, "weight": weight, "mean": mean, "rstd": rstd}), + ((), {"dy": dout, "x": x, "w": weight, "mean": mean, "rstd": rstd}), + ): + try: + out = quack_layernorm_bwd(*args, **kwargs) + break + except TypeError as e: + call_errors.append(str(e)) + else: + raise TypeError( + "Unable to call Quack LayerNorm backward: " + " | ".join(call_errors) + ) + + if not isinstance(out, tuple): + raise TypeError( + f"Expected Quack LayerNorm backward to return a tuple, got {type(out)}" + ) + if len(out) == 2: + dx, dw = out + db = None + elif len(out) >= 3: + dx, dw, db = out[:3] + else: + raise TypeError( + f"Expected Quack LayerNorm backward tuple with >=2 values, got {len(out)}" + ) + return dx, dw, None if bias is None else db + + +def parse_weight_dtype(arg: str, activation_dtype: torch.dtype) -> torch.dtype: + if arg == "same": + return activation_dtype + return parse_dtype(arg) + + +def bytes_io_model_layernorm_bwd( + M: int, + N: int, + dtype: torch.dtype, + *, + weight_dtype: torch.dtype, + has_bias: bool, +) -> int: + """Useful logical IO model for LayerNorm backward. + + The model intentionally excludes implementation-specific partial-gradient + scratch traffic so Oink and PyTorch can be compared on the same useful + read/write work. + """ + elem = torch.tensor(0, dtype=dtype).element_size() + w_elem = torch.tensor(0, dtype=weight_dtype).element_size() + + # Read x + dout, write dx. + total = 3 * M * N * elem + # Read gamma, write dgamma. + total += 2 * N * w_elem + # Read mean + rstd (fp32 per row). + total += 2 * M * 4 + if has_bias: + # Write dbias. Bias reads are not needed for LayerNorm backward. 
+ total += N * w_elem + return int(total) + + +def _compute_stats(x: torch.Tensor, eps: float) -> Tuple[torch.Tensor, torch.Tensor]: + xf = x.float() + mean = xf.mean(dim=-1).to(torch.float32) + var = ((xf - mean.unsqueeze(1)) ** 2).mean(dim=-1) + rstd = torch.rsqrt(var + eps).to(torch.float32) + return mean, rstd + + +def _call_oink( + dout: torch.Tensor, + x: torch.Tensor, + weight: torch.Tensor, + mean: torch.Tensor, + rstd: torch.Tensor, + bias: Optional[torch.Tensor], +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + return _get_oink_layernorm().layernorm_backward( + dout, x, weight, rstd, mean, bias=bias + ) + + +def _call_ref( + dout: torch.Tensor, + x: torch.Tensor, + weight: torch.Tensor, + mean: torch.Tensor, + rstd: torch.Tensor, + bias: Optional[torch.Tensor], +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + # Native ATen backward is the fastest available PyTorch reference that can + # reuse the same precomputed mean/rstd as the Oink/cute backends. + return torch.ops.aten.native_layer_norm_backward.default( + dout, + x, + [int(x.shape[-1])], + mean, + rstd, + weight, + bias, + [True, True, bias is not None], + ) + + +def _available_backend_fns( + dout: torch.Tensor, + x: torch.Tensor, + weight: torch.Tensor, + mean: torch.Tensor, + rstd: torch.Tensor, + bias: Optional[torch.Tensor], +) -> Dict[str, BackendFn]: + fns: Dict[str, BackendFn] = { + "ours": lambda: _call_oink(dout, x, weight, mean, rstd, bias), + "ref": lambda: _call_ref(dout, x, weight, mean, rstd, bias), + } + if quack_layernorm_bwd is not None and bias is None: + fns["quack"] = lambda: _call_quack_layernorm_bwd( + dout, x, weight, mean, rstd, bias + ) + return fns + + +def _dweight_tolerance( + dtype: torch.dtype, dw_ref: torch.Tensor +) -> Optional[Dict[str, float]]: + if dtype == torch.float32: + return dict(atol=2e-3, rtol=1e-3) + dw_ref_f32 = dw_ref.to(torch.float32) + scale = float(dw_ref_f32.abs().max().item()) + atol = max(2.0 * torch.finfo(dtype).eps * scale, 1e-3) + return dict(atol=float(atol), rtol=1e-3) + + +def _unpack_backend_output( + out: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]], +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + if not isinstance(out, tuple) or len(out) < 2: + raise TypeError( + f"Expected a tuple containing at least dx and dweight, got {type(out)}" + ) + dx = out[0] + dw = out[1] + db = out[2] if len(out) > 2 else None + return dx, dw, db + + +def _verify_parity( + *, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + dout: torch.Tensor, + mean: torch.Tensor, + rstd: torch.Tensor, + backend_fns: Dict[str, BackendFn], +) -> Dict[str, object]: + tol_dx = _VERIFY_TOL_DX[x.dtype] + M, N = int(x.shape[0]), int(x.shape[1]) + ref_block_rows = 1024 + + outputs: Dict[ + str, Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]] + ] = {} + with torch.no_grad(): + for name, fn in backend_fns.items(): + outputs[name] = _unpack_backend_output(fn()) + torch.cuda.synchronize() + + dx_accs = { + name: ErrorStatsAccumulator(total_elems=M * N) for name in outputs.keys() + } + dw_accum = torch.zeros((N,), device=x.device, dtype=torch.float32) + db_accum = ( + torch.zeros((N,), device=x.device, dtype=torch.float32) + if bias is not None + else None + ) + weight_f32 = weight.float() + + for start, end in iter_row_blocks(M, ref_block_rows): + x_f32 = x[start:end].float() + dout_f32 = dout[start:end].float() + mean_blk = mean[start:end].unsqueeze(1) + 
rstd_blk = rstd[start:end].unsqueeze(1) + x_hat = (x_f32 - mean_blk) * rstd_blk + wdy = dout_f32 * weight_f32 + mean_wdy = wdy.mean(dim=-1, keepdim=True) + mean_xhat_wdy = (x_hat * wdy).mean(dim=-1, keepdim=True) + dx_ref = ((wdy - mean_wdy - x_hat * mean_xhat_wdy) * rstd_blk).to(x.dtype) + + for name, (dx, _, _) in outputs.items(): + torch.testing.assert_close( + dx[start:end], + dx_ref, + **tol_dx, + msg=f"{name} dx mismatch M={M} N={N} rows={start}:{end}", + ) + dx_accs[name].update(dx[start:end], dx_ref) + + dw_accum += (dout_f32 * x_hat).sum(dim=0) + if db_accum is not None: + db_accum += dout_f32.sum(dim=0) + + stats: Dict[str, object] = {} + for name, acc in dx_accs.items(): + stats.update(error_stats_to_row(f"{name}_err_dx", acc.finalize())) + + dw_ref = dw_accum.to(weight.dtype) + dw_tol = _dweight_tolerance(weight.dtype, dw_ref) + for name, (_, dw, _) in outputs.items(): + if dw is None: + raise AssertionError(f"{name} did not return dweight for M={M} N={N}") + torch.testing.assert_close( + dw, + dw_ref, + **dw_tol, + msg=f"{name} dweight mismatch M={M} N={N}", + ) + dw_acc = ErrorStatsAccumulator( + total_elems=int(dw_ref.numel()), p99_target_samples=int(dw_ref.numel()) + ) + dw_acc.update(dw, dw_ref) + stats.update(error_stats_to_row(f"{name}_err_dw", dw_acc.finalize())) + + if bias is not None: + assert db_accum is not None + db_ref = db_accum.to(bias.dtype) + db_tol = _dweight_tolerance(bias.dtype, db_ref) + for name, (_, _, db) in outputs.items(): + if db is None: + raise AssertionError(f"{name} did not return dbias for M={M} N={N}") + torch.testing.assert_close( + db, + db_ref, + **db_tol, + msg=f"{name} dbias mismatch M={M} N={N}", + ) + db_acc = ErrorStatsAccumulator( + total_elems=int(db_ref.numel()), p99_target_samples=int(db_ref.numel()) + ) + db_acc.update(db, db_ref) + stats.update(error_stats_to_row(f"{name}_err_db", db_acc.finalize())) + + return stats + + +def bench_single( + M: int, + N: int, + dtype: torch.dtype, + weight_dtype: torch.dtype, + *, + eps: float, + warmup_ms: int, + iters_ms: int, + verify: bool, + has_bias: bool, + cuda_graph: bool, +) -> Tuple[Dict[str, BenchResult], Dict[str, object]]: + device = torch.device("cuda") + x = torch.randn(M, N, device=device, dtype=dtype) + weight = torch.randn(N, device=device, dtype=weight_dtype) + bias = torch.randn(N, device=device, dtype=weight_dtype) if has_bias else None + dout = torch.randn(M, N, device=device, dtype=dtype) + mean, rstd = _compute_stats(x, eps) + + backend_fns = _available_backend_fns(dout, x, weight, mean, rstd, bias) + stats: Dict[str, object] = {} + if verify: + stats = _verify_parity( + x=x, + weight=weight, + bias=bias, + dout=dout, + mean=mean, + rstd=rstd, + backend_fns=backend_fns, + ) + + bytes_io = bytes_io_model_layernorm_bwd( + M, N, dtype, weight_dtype=weight_dtype, has_bias=has_bias + ) + results: Dict[str, BenchResult] = {} + for name, fn in backend_fns.items(): + if cuda_graph: + # Warm outside graph so CuTeDSL compile/cache and workspace allocation are + # not captured in the measured replay. Keep the capture-time outputs + # alive until after replay timing; this avoids tearing down graph-owned + # allocations while the captured graph is still being measured. 
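+            # Timing note: do_bench_cuda_graph (bench_utils) delegates to
+            # triton.testing.do_bench_cudagraph, so graph_fn is captured into a
+            # CUDA graph once and the reported number is the mean over graph
+            # replays, excluding per-call Python/launch overhead.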
+ graph_outputs: list[object] = [None] + + def graph_fn() -> object: + graph_outputs[0] = fn() + return graph_outputs[0] + + graph_fn() + torch.cuda.synchronize() + ms = do_bench_cuda_graph(graph_fn, rep_ms=iters_ms) + else: + ms = do_bench_triton(fn, warmup_ms=warmup_ms, rep_ms=iters_ms) + gbps = bytes_io / (ms * 1e-3) / 1e9 + results[name] = BenchResult(ms=ms, gbps=gbps) + + return results, stats + + +def _add_backend_result(row: Dict[str, object], name: str, result: BenchResult) -> None: + prefix = "ours" if name == "ours" else name + row[f"{prefix}_ms"] = result.ms + row[f"{prefix}_gbps"] = result.gbps + row[f"{prefix}_tbps"] = result.tbps + + +def _append_speedups(row: Dict[str, object], results: Dict[str, BenchResult]) -> None: + ours = results.get("ours") + if ours is None: + return + for name, result in results.items(): + if name == "ours": + continue + row[f"speedup_vs_{name}"] = result.ms / ours.ms + + +def _print_summary(rows: List[Dict[str, object]]) -> None: + base_headers = ["M", "N", "dtype", "weight_dtype", "ours_ms", "ours_tbps"] + optional_headers = [ + "ref_ms", + "ref_tbps", + "speedup_vs_ref", + "quack_ms", + "quack_tbps", + "speedup_vs_quack", + ] + headers = base_headers + [h for h in optional_headers if any(h in r for r in rows)] + + print("\nSummary:") + print(" ".join(h.rjust(22) for h in headers)) + for row in rows: + parts: List[str] = [] + for header in headers: + value = row.get(header) + if isinstance(value, float): + parts.append(f"{value:22.4f}") + else: + parts.append(f"{str(value):>22}") + print(" ".join(parts)) + + +def main() -> None: + if not torch.cuda.is_available(): + raise SystemExit("CUDA not available") + + torch.cuda.set_device(0) + device = torch.device("cuda") + props = torch.cuda.get_device_properties(device) + sm = props.major * 10 + props.minor + print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})") + print(f"Quack LayerNorm backward: {_QUACK_LAYERNORM_BWD_STATUS}") + + parser = argparse.ArgumentParser() + parser.add_argument( + "--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"] + ) + parser.add_argument( + "--weight-dtype", + type=str, + default="same", + choices=["same", "fp16", "bf16", "fp32"], + help="LayerNorm weight dtype. `same` matches activation dtype.", + ) + parser.add_argument("--eps", type=float, default=1e-6) + parser.add_argument("--with-bias", action="store_true") + parser.add_argument( + "--iters", type=int, default=100, help="Triton do_bench rep_ms (kernel-only)." 
+ ) + parser.add_argument("--warmup-ms", type=int, default=25) + parser.add_argument( + "--csv", type=str, default=None, help="Optional CSV output path; appends rows" + ) + parser.add_argument( + "--json", type=str, default=None, help="Optional JSON output path (meta + rows)" + ) + parser.add_argument("--configs", type=str, default="1024x4096,8192x4096") + parser.add_argument( + "--quack-suite", action="store_true", help="Run Quack-style batch/seq grid" + ) + parser.add_argument( + "--dsv3", + action="store_true", + help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,8192}", + ) + parser.add_argument( + "--dsv4", + action="store_true", + help="Run DeepSeek-V4-Flash hidden LayerNorm set: M in {4096,16384,65536}, N=7168", + ) + parser.add_argument( + "--skip-verify", + action="store_true", + help="Skip correctness checks before timing", + ) + parser.add_argument( + "--cuda-graph", + action="store_true", + help="Time warm CUDA-graph replay instead of eager do_bench calls.", + ) + args = parser.parse_args() + + dtype = parse_dtype(args.dtype) + weight_dtype = parse_weight_dtype(args.weight_dtype, dtype) + eps = float(args.eps) + + if args.quack_suite: + cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()] + elif args.dsv3: + cfgs = dsv3_configs() + elif args.dsv4: + cfgs = dsv4_hidden_norm_configs() + else: + cfgs = parse_configs(args.configs) + + hbm_peak = detect_hbm_peak_gbps(device) + rows_out: List[Dict[str, object]] = [] + for M, N in cfgs: + print( + f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} " + f"weight_dtype={args.weight_dtype} ...", + flush=True, + ) + results, stats = bench_single( + M=M, + N=N, + dtype=dtype, + weight_dtype=weight_dtype, + eps=eps, + warmup_ms=int(args.warmup_ms), + iters_ms=int(args.iters), + verify=not args.skip_verify, + has_bias=bool(args.with_bias), + cuda_graph=bool(args.cuda_graph), + ) + + row: Dict[str, object] = { + "M": M, + "N": N, + "dtype": args.dtype, + "weight_dtype": args.weight_dtype, + "eps": eps, + "with_bias": bool(args.with_bias), + "bytes_io": bytes_io_model_layernorm_bwd( + M, N, dtype, weight_dtype=weight_dtype, has_bias=bool(args.with_bias) + ), + } + for name, result in results.items(): + _add_backend_result(row, name, result) + if "ours" in results: + row["ours_hbm_frac"] = results["ours"].gbps / hbm_peak + _append_speedups(row, results) + row.update(stats) + rows_out.append(row) + + if args.csv is not None: + write_csv(args.csv, rows_out) + if args.json is not None: + meta = collect_device_meta(device) + write_json( + args.json, + meta, + rows_out, + extra={ + "method": ( + "triton.testing.do_bench_cudagraph(mean)" + if args.cuda_graph + else "triton.testing.do_bench(mean)" + ), + "cuda_graph": bool(args.cuda_graph), + "warmup_ms": int(args.warmup_ms), + "rep_ms": int(args.iters), + "io_model_bytes": "see bytes_io_model_layernorm_bwd in script", + "quack_layernorm_bwd_status": _QUACK_LAYERNORM_BWD_STATUS, + "reference_backend": "torch.ops.aten.native_layer_norm_backward.default", + }, + ) + + _print_summary(rows_out) + + if args.cuda_graph: + # Some torch/CUDAGraph allocator combinations can segfault during Python + # finalization after captured allocation-heavy benchmark functions have + # already written valid results. Exit directly after flushing benchmark + # output so graph replay CLI runs return success deterministically. 
+ sys.stdout.flush() + sys.stderr.flush() + os._exit(0) + + +if __name__ == "__main__": + main() diff --git a/oink/benchmarks/readme/summarize_results.py b/oink/benchmarks/readme/summarize_results.py index efc92e4..607f04c 100644 --- a/oink/benchmarks/readme/summarize_results.py +++ b/oink/benchmarks/readme/summarize_results.py @@ -65,6 +65,14 @@ def _pick_columns(rows: Sequence[Dict[str, Any]]) -> List[str]: "ours_ms", "ours_tbps", "ours_hbm_frac", + "ref_ms", + "ref_tbps", + "speedup_vs_ref", + # Historical/external JSON compatibility; current OSS Oink benchmarks do + # not import a cute-kernels baseline directly. + "cute_kernels_ms", + "cute_kernels_tbps", + "speedup_vs_cute_kernels", "quack_ms", "quack_tbps", "speedup_vs_quack", @@ -168,7 +176,8 @@ def summarize_one(path: str) -> str: ts = meta.get("timestamp") parts.append("") parts.append( - f"- device: `{device}` | capability: `{cap}` | torch: `{torch_ver}` | cuda: `{cuda_ver}` | git_sha: `{git_sha}` | timestamp: `{ts}`" + f"- device: `{device}` | capability: `{cap}` | torch: `{torch_ver}` | " + f"cuda: `{cuda_ver}` | git_sha: `{git_sha}` | timestamp: `{ts}`" ) method = meta.get("method") if method is not None: @@ -182,13 +191,23 @@ def summarize_one(path: str) -> str: parts.append("") parts.append(_md_table(rows, cols)) - speeds = [float(r["speedup_vs_quack"]) for r in rows if "speedup_vs_quack" in r] - gm = _geomean(speeds) - if gm is not None: - parts.append("") - parts.append( - f"- geomean speedup vs Quack: `{gm:.3f}x` (over {len(speeds)} shapes)" - ) + speedup_cols = sorted( + { + k + for r in rows + for k in r.keys() + if isinstance(k, str) and k.startswith("speedup_vs_") + } + ) + for col in speedup_cols: + speeds = [float(r[col]) for r in rows if col in r] + gm = _geomean(speeds) + if gm is not None: + baseline = col.removeprefix("speedup_vs_").replace("_", " ") + parts.append("") + parts.append( + f"- geomean speedup vs {baseline}: `{gm:.3f}x` (over {len(speeds)} shapes)" + ) err_block = _summarize_error_stats(rows) if err_block: diff --git a/oink/src/kernelagent_oink/blackwell/layernorm.py b/oink/src/kernelagent_oink/blackwell/layernorm.py index 6b4b9c7..8a24a15 100644 --- a/oink/src/kernelagent_oink/blackwell/layernorm.py +++ b/oink/src/kernelagent_oink/blackwell/layernorm.py @@ -14,7 +14,7 @@ # limitations under the License. """ -LayerNorm kernel for SM100 (Blackwell) in CuteDSL. +LayerNorm kernels for Blackwell SM10x in CuteDSL. This implementation: - Mirrors Quack's LayerNorm tiling / cluster policy / cp.async pipeline @@ -24,15 +24,17 @@ - Optionally writes out per-row `rstd` and `mean` buffers for reuse in backward or fused kernels. -Backward is implemented with dedicated CuteDSL kernels for input and -parameter gradients (dx, dweight, dbias), avoiding PyTorch autograd -while matching `torch.nn.functional.layer_norm`'s gradients numerically. +Backward is self-contained in this repo. Validated GB300/SM103 DSv3/DSv4 +hidden sizes use the fused pointer fast path; unsupported shapes/layouts use +local fallback kernels for dx and parameter gradients while matching +`torch.nn.functional.layer_norm`'s gradients numerically. 
""" from __future__ import annotations import math import operator +from functools import partial from typing import Optional, Tuple, Type import torch @@ -60,11 +62,18 @@ from kernelagent_oink.blackwell.lite_quack import ( _KERNEL_ACCEPTS_LAYOUT_ARGS, TORCH2CUTE_DTYPE, + RMSNormBackward as _LiteRMSNormBackward, ReductionBase as _ReductionBase, + atomic_add_tensor_f32, convert_from_dlpack as convert_from_dlpack_cute, + coord_offset_i64, + copy as _quack_copy, + fill_oob, + get_copy_atom, get_sm_count, predicate_k, row_reduce, + row_reduce_add, warp_reduce, ) from kernelagent_oink.blackwell.fast_launch import ( @@ -77,12 +86,203 @@ ) # Simple compile cache for the forward kernel -_COMPILE_CACHE: dict[Tuple[int, type[cutlass.Numeric], bool, bool, bool], object] = {} +_COMPILE_CACHE: dict[Tuple[object, ...], object] = {} _PTR_COMPILE_CACHE: dict[Tuple[object, ...], object] = {} # Backward compile caches: one for dx, one for parameter gradients. -_BWD_DX_COMPILE_CACHE: dict[Tuple[int, Type[cutlass.Numeric]], object] = {} -_BWD_PARAM_COMPILE_CACHE: dict[Tuple[int, Type[cutlass.Numeric], bool], object] = {} +_BWD_DX_COMPILE_CACHE: dict[ + Tuple[int, Type[cutlass.Numeric], Type[cutlass.Numeric], Type[cutlass.Numeric]], + object, +] = {} +_BWD_PARAM_COMPILE_CACHE: dict[ + Tuple[int, Type[cutlass.Numeric], Type[cutlass.Numeric], bool], object +] = {} +_BWD_PTR_COMPILE_CACHE: dict[Tuple[object, ...], object] = {} +_BWD_WORKSPACE_CACHE: dict[ + Tuple[int, int, int, int, bool], Tuple[Tensor, Optional[Tensor]] +] = {} +_BWD_COMBINED_BIAS_WORKSPACE_CACHE: dict[ + Tuple[int, int, int, int], Tuple[Tensor, Tensor] +] = {} +_BWD_COMBINED_PAIR_CACHE: dict[Tuple[int, int], Tuple[Tensor, Tensor]] = {} +_BWD_ATOMIC_WORKSPACE_CACHE: dict[Tuple[int, int, int], Tensor] = {} +_BWD_REDUCTION_STREAM_CACHE: dict[int, Tuple[torch.cuda.Stream, torch.cuda.Stream]] = {} + + +def _reduce_partial_sum_fp32(partial: Tensor, *, device_index: int) -> Tensor: + """Reduce a (sm_count, N) fp32 partial buffer into an (N,) fp32 result.""" + assert partial.dtype is torch.float32 + assert partial.dim() == 2 + # On GB300, the generic reduction kernel used by `sum(dim=0)` is faster for + # these LayerNorm partial buffers than routing the reduction through GEMM. + _ = device_index # kept for call-site compatibility / future tuning. 
+ return partial.sum(dim=0) + + +def _get_layernorm_bwd_workspace( + *, + device_index: int, + stream_handle: int, + sm_count: int, + N: int, + has_bias: bool, +) -> Tuple[Tensor, Optional[Tensor]]: + key = ( + int(device_index), + int(stream_handle), + int(sm_count), + int(N), + bool(has_bias), + ) + cached = _BWD_WORKSPACE_CACHE.get(key) + if cached is not None: + if has_bias: + dw_partial, db_partial = cached + assert db_partial is not None + pair_key = (int(dw_partial.data_ptr()), int(db_partial.data_ptr())) + if pair_key not in _BWD_COMBINED_PAIR_CACHE: + combined_key = ( + int(device_index), + int(stream_handle), + int(sm_count), + int(N), + ) + combined_cached = _BWD_COMBINED_BIAS_WORKSPACE_CACHE.get(combined_key) + if combined_cached is not None: + _BWD_COMBINED_PAIR_CACHE[pair_key] = combined_cached + return cached + + device = torch.device("cuda", device_index) + if has_bias: + combined_partial = torch.empty( + (2, sm_count, N), device=device, dtype=torch.float32 + ) + dw_partial = combined_partial[0] + db_partial = combined_partial[1] + reduced_pair = torch.empty((2, N), device=device, dtype=torch.float32) + _BWD_COMBINED_BIAS_WORKSPACE_CACHE[ + (int(device_index), int(stream_handle), int(sm_count), int(N)) + ] = (combined_partial, reduced_pair) + _BWD_COMBINED_PAIR_CACHE[ + (int(dw_partial.data_ptr()), int(db_partial.data_ptr())) + ] = (combined_partial, reduced_pair) + else: + dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32) + db_partial = None + cached = (dw_partial, db_partial) + _BWD_WORKSPACE_CACHE[key] = cached + return cached + + +def _get_layernorm_bwd_reduction_streams( + device: torch.device, +) -> Tuple[torch.cuda.Stream, torch.cuda.Stream]: + device_index = ( + device.index if device.index is not None else torch.cuda.current_device() + ) + cached = _BWD_REDUCTION_STREAM_CACHE.get(int(device_index)) + if cached is not None: + return cached + streams = ( + torch.cuda.Stream(device=device), + torch.cuda.Stream(device=device), + ) + _BWD_REDUCTION_STREAM_CACHE[int(device_index)] = streams + return streams + + +def _get_layernorm_bwd_atomic_dw_workspace( + *, + device_index: int, + stream_handle: int, + N: int, +) -> Tensor: + key = (int(device_index), int(stream_handle), int(N)) + cached = _BWD_ATOMIC_WORKSPACE_CACHE.get(key) + if cached is not None: + return cached + + dw_acc = torch.empty( + N, device=torch.device("cuda", device_index), dtype=torch.float32 + ) + _BWD_ATOMIC_WORKSPACE_CACHE[key] = dw_acc + return dw_acc + + +def _finalize_layernorm_bwd_partials( + *, + dw_partial: Tensor, + db_partial: Optional[Tensor], + weight: Tensor, + bias: Optional[Tensor], + device: torch.device, +) -> Tuple[Tensor, Optional[Tensor]]: + if db_partial is not None: + pair_cached = _BWD_COMBINED_PAIR_CACHE.get( + (int(dw_partial.data_ptr()), int(db_partial.data_ptr())) + ) + if pair_cached is not None: + combined_partial, reduced_pair = pair_cached + torch.sum(combined_partial, dim=1, out=reduced_pair) + dweight_fp32 = reduced_pair[0] + dbias_fp32 = reduced_pair[1] + dweight = ( + dweight_fp32.clone() + if weight.dtype == torch.float32 + else dweight_fp32.to(weight.dtype) + ) + assert bias is not None + dbias = ( + dbias_fp32.clone() + if bias.dtype == torch.float32 + else dbias_fp32.to(bias.dtype) + ) + return dweight, dbias + + if db_partial is None: + # No-bias is the hot DSv3 path. Avoid auxiliary stream handshakes here; + # the reduction is small and launch latency dominates for M=4096. 
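+        # dw_partial is the kernel's (sm_count, N) fp32 partial-sum workspace;
+        # _reduce_partial_sum_fp32 collapses it over dim 0 into the final (N,)
+        # dweight on the current stream, and the cast below restores weight.dtype.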
+ dweight_fp32 = _reduce_partial_sum_fp32( + dw_partial, device_index=weight.get_device() + ) + dweight = ( + dweight_fp32 + if weight.dtype == torch.float32 + else dweight_fp32.to(weight.dtype) + ) + return dweight, None + + dw_stream, db_stream = _get_layernorm_bwd_reduction_streams(device) + current_stream = torch.cuda.current_stream(device=device) + + dw_stream.wait_stream(current_stream) + db_stream.wait_stream(current_stream) + + with torch.cuda.stream(dw_stream): + dweight_fp32 = _reduce_partial_sum_fp32( + dw_partial, device_index=weight.get_device() + ) + dweight = ( + dweight_fp32 + if weight.dtype == torch.float32 + else dweight_fp32.to(weight.dtype) + ) + + assert bias is not None + with torch.cuda.stream(db_stream): + dbias_fp32 = _reduce_partial_sum_fp32( + db_partial, device_index=bias.get_device() + ) + dbias = dbias_fp32 if bias.dtype == torch.float32 else dbias_fp32.to(bias.dtype) + + current_stream.wait_stream(dw_stream) + current_stream.wait_stream(db_stream) + + return dweight, dbias + + +def _finalize_layernorm_bwd_atomic_dw(*, dw_acc: Tensor, weight: Tensor) -> Tensor: + return dw_acc.clone() if weight.dtype == torch.float32 else dw_acc.to(weight.dtype) class _PtrLayernormFastLaunch: @@ -324,6 +524,8 @@ def _fallback_launch( eps: float, ) -> None: dtype_x = TORCH2CUTE_DTYPE[x.dtype] + dtype_w = TORCH2CUTE_DTYPE[weight.dtype] + dtype_b = TORCH2CUTE_DTYPE[bias.dtype] if bias is not None else None stream_handle = int(torch.cuda.current_stream().cuda_stream) stream = cuda.CUstream(stream_handle) ptr_x = rt.make_ptr( @@ -339,14 +541,14 @@ def _fallback_launch( assumed_align=self._assumed_align_xo, ) ptr_w = rt.make_ptr( - cutlass.Float32, + dtype_w, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16, ) ptr_b = ( rt.make_ptr( - cutlass.Float32, + dtype_b, bias.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16, @@ -393,6 +595,8 @@ def _get_fast_ptr_layernorm_launcher( compiled: object, N: int, dtype_x: type[cutlass.Numeric], + dtype_w: type[cutlass.Numeric], + dtype_b: Optional[type[cutlass.Numeric]], has_bias: bool, has_rstd: bool, has_mean: bool, @@ -408,6 +612,8 @@ def _get_fast_ptr_layernorm_launcher( id(compiled), int(N), dtype_x, + dtype_w, + dtype_b, bool(has_bias), bool(has_rstd), bool(has_mean), @@ -426,13 +632,9 @@ def _get_fast_ptr_layernorm_launcher( ptr_out = rt.make_ptr( dtype_x, 0, mem_space=rt.AddressSpace.gmem, assumed_align=int(assumed_align_xo) ) - ptr_w = rt.make_ptr( - cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16 - ) + ptr_w = rt.make_ptr(dtype_w, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16) ptr_b = ( - rt.make_ptr( - cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16 - ) + rt.make_ptr(dtype_b, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16) if has_bias else None ) @@ -508,241 +710,1307 @@ def _get_fast_ptr_layernorm_launcher( return launcher -def _convert_row_major(t: Tensor) -> cute.Tensor: - """ - Convert a 2D row-major torch.Tensor to a CuTeDSL tensor with a compact, - dynamic layout on the leading dimension. - """ - return from_dlpack(t.detach(), assumed_align=16).mark_compact_shape_dynamic( - mode=0, - stride_order=(0, 1), - ) - - -class LayerNormSM100(_ReductionBase): - """ - SM100 LayerNorm forward kernel. +class _LayerNormBackwardSM103(_LiteRMSNormBackward): + """Self-contained Oink LayerNorm backward for SM103. 
- This mirrors `quack.layernorm.LayerNorm`'s schedule: - - Stage=2 pipeline: first pass computes mean, second pass computes - variance / rstd and normalization. - - Threads-per-row and cluster_n policy follow Quack's LayerNorm - heuristics to keep tensor-core friendly tiles across N. - - Optional `reload_from` hint enables reloading X from SMEM for large-N - shapes to shorten register lifetimes. + Contract: + - math: ``x_hat = (x - mean) * rstd``; + ``dx = (wdy - mean(wdy) - x_hat * mean(x_hat * wdy)) * rstd``; + ``dweight = sum(dout * x_hat)`` and optional ``dbias = sum(dout)``. + - precision: row reductions and partial parameter gradients use fp32; + ``dx`` is stored in the input dtype and final parameter gradients are + cast to the weight/bias dtype by the host finalizer. + - variants: partial-buffer accumulation is the default; atomic dW/dB are + compile-time flags enabled only by host-side shape policy. + - dispatch: the GB300 pointer path is fenced by dtype/layout/alignment + checks below; unsupported layouts fall back to local non-pointer kernels. - Differences vs Quack: - - Bias is optional and supported directly in the kernel. - - Dtype mapping and reduction helpers come from `lite_quack`. + The implementation intentionally depends only on repo-local helpers and + never dispatches to external kernels. """ - def __init__( - self, - dtype: type[cutlass.Numeric], - N: int, - *, - copy_bits_x: Optional[int] = None, - direct_gmem: bool = False, - ): - super().__init__(dtype, N, stage=2) # 2 stages for mean and var - # Default reload policy mirrors Quack: use SMEM reload only for - # very large hidden sizes. We keep this conservative for LayerNorm - # and tune primarily via threads-per-block / cluster_n. - self.reload_from: Optional[str] = None if N <= 16384 else "smem" - # SM100 tuning: for DSv3 hidden sizes where we fuse mean+var stats, - # delay loading fp32 weights/bias until after the reductions to lower - # register pressure. - self.delay_w_load: bool = bool(N in (4096, 6144, 7168, 8192)) - self.copy_bits_x: Optional[int] = ( - int(copy_bits_x) if copy_bits_x is not None else None - ) - self.direct_gmem: bool = bool(direct_gmem) + def __init__(self, dtype: type[cutlass.Numeric], N: int): + super().__init__(dtype, N) + self.atomic_dw = False def _get_num_threads(self) -> int: nt = getattr(self, "_nt_override", None) if nt is not None: return int(nt) - return super()._get_num_threads() + return 128 if self.N <= 4096 else 256 def _calculate_threads_per_row(self) -> int: tpr = getattr(self, "_tpr_override", None) if tpr is not None: return int(tpr) - # Match Quack's LayerNorm threads-per-row buckets. N = self.N - if N in (4096, 6144): - return 128 - return ( - 8 - if N <= 64 - else ( - 16 - if N <= 128 - else ( - 32 - if N <= 3072 - else (64 if N <= 6144 else (128 if N <= 16384 else 256)) - ) - ) - ) + for limit, threads in [ + (64, 8), + (128, 16), + (256, 32), + (512, 64), + (4096, 128), + ]: + if N <= limit: + return threads + return 256 def _set_cluster_n(self) -> None: - # Cluster_n policy mirrors quack.layernorm.LayerNorm._set_cluster_n. 
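+        # Backward cluster_n policy: stay single-CTA for N <= 8192 (covers the
+        # DSv3/DSv4 hidden sizes exercised by the benchmarks) and only widen the
+        # cluster for very large N, bucketed by dtype width below.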
+ cn = getattr(self, "_cluster_n_override", None) + if cn is not None: + self.cluster_n = int(cn) + return + N = self.N - if const_expr(self.dtype.width == 16): - cluster_n = ( - 1 - if N <= 16 * 1024 - else ( - 2 - if N <= 32 * 1024 - else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)) - ) - ) + if N <= 8192: + cluster_n = 1 + elif self.dtype.width == 16: + if N <= 16 * 1024: + cluster_n = 2 + elif N <= 32 * 1024: + cluster_n = 2 + elif N <= 64 * 1024: + cluster_n = 4 + elif N <= 128 * 1024: + cluster_n = 8 + else: + cluster_n = 16 else: - cluster_n = ( - 1 - if N <= 32 * 1024 - else ( - 2 - if N <= 64 * 1024 - else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16)) - ) - ) + if N <= 32 * 1024: + cluster_n = 1 + elif N <= 64 * 1024: + cluster_n = 2 + elif N <= 128 * 1024: + cluster_n = 4 + elif N <= 256 * 1024: + cluster_n = 8 + else: + cluster_n = 16 self.cluster_n = cluster_n @cute.jit def __call__( self, mX: cute.Tensor, - mW: cute.Tensor, - mB: Optional[cute.Tensor], - mO: cute.Tensor, - mRstd: Optional[cute.Tensor], - mMean: Optional[cute.Tensor], + mW: cute.Tensor | None, + mdO: cute.Tensor, + mMean: cute.Tensor, + mRstd: cute.Tensor, + mdX: cute.Tensor, + mdW: cute.Tensor | None, + mdB: cute.Tensor | None, + sm_count: Int32, stream: cuda.CUstream, - eps: Float32 = 1e-6, ): - assert mX.element_type == self.dtype - assert mO.element_type == self.dtype + semistatic_shape = (*mX.shape[:-1], self.N) - # Tiling and cluster policy (mirrors Quack LayerNorm). + def new_stride(t): + return ( + cute.assume(t.stride[0], divby=128 // t.element_type.width), + t.stride[1], + ) + + mX, mdO, mdX = [ + cute.make_tensor( + t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)) + ) + for t in (mX, mdO, mdX) + ] self._set_cluster_n() largest_dtype_width = const_expr( max( - t.element_type.width - for t in (mX, mW, mB, mO, mRstd, mMean) - if t is not None + mX.element_type.width, + mW.element_type.width if mW is not None else 0, + mdO.element_type.width, + mdX.element_type.width, ) ) - # Match Quack's unified RMSNorm/LayerNorm kernel: pick vecsize based on - # the widest dtype participating in the op (e.g. fp32 weights => fp16 - # X uses 64b vectorization). - vecsize = math.gcd(self.N, 128 // largest_dtype_width) - default_copy_bits_x = vecsize * self.dtype.width - num_copy_bits_x = ( - int(self.copy_bits_x) - if self.copy_bits_x is not None - else default_copy_bits_x - ) - tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits_x) + num_copy_bits = const_expr(128 // largest_dtype_width * mX.element_type.width) + tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=int(num_copy_bits)) num_threads = ( cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads() ) num_warps = num_threads // cute.arch.WARP_SIZE - - # Expand weight / bias to match tiler_mn[0] rows per CTA. 
- mW = cute.make_tensor( - mW.iterator, - cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))), - ) - if const_expr(mB is not None): - mB = cute.make_tensor( - mB.iterator, - cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))), - ) - if const_expr(mRstd is not None): - mRstd = cute.make_tensor( - mRstd.iterator, - cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,))), - ) - if const_expr(mMean is not None): - mMean = cute.make_tensor( - mMean.iterator, - cute.append(mMean.layout, cute.make_layout((self.N,), stride=(0,))), + if const_expr(mW is not None): + mW_expanded_layout = cute.prepend( + mW.layout, + cute.make_layout((tiler_mn[0],), stride=(0,)), ) + mW = cute.make_tensor(mW.iterator, mW_expanded_layout) + num_blocks = sm_count kernel = ( - self.kernel( - mX, - mW, - mB, - mO, - mRstd, - mMean, - eps, - tv_layout, - tiler_mn, - const_expr(self.reload_from), - const_expr(self.delay_w_load), - ) + self.kernel(mX, mW, mdO, mMean, mRstd, mdX, mdW, mdB, tv_layout, tiler_mn) if _KERNEL_ACCEPTS_LAYOUT_ARGS - else self.kernel( - mX, - mW, - mB, - mO, - mRstd, - mMean, - eps, - ) + else self.kernel(mX, mW, mdO, mMean, mRstd, mdX, mdW, mdB) ) kernel.launch( - grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1], + grid=[num_blocks, self.cluster_n, 1], block=[num_threads, 1, 1], - cluster=[ - 1, - self.cluster_n, - 1, - ] - if const_expr(self.cluster_n > 1) - else None, - smem=self._smem_size_in_bytes(tiler_mn, num_warps), + cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None, + smem=self._smem_size_in_bytes( + tiler_mn, num_warps, do_dtype=mdO.element_type + ), stream=stream, ) @cute.jit - def launch_from_ptrs( + def _kernel_impl( self, - ptr_x: cute.Pointer, - ptr_w: cute.Pointer, - ptr_b: Optional[cute.Pointer], - ptr_out: cute.Pointer, - ptr_rstd: Optional[cute.Pointer], - ptr_mean: Optional[cute.Pointer], - M: Int32, - ld: Int32, - stream: cuda.CUstream, - eps: Float32 = 1e-6, - ) -> None: - """Pointer-based entrypoint that bypasses DLPack conversions. + mX: cute.Tensor, + mW: cute.Tensor | None, + mdO: cute.Tensor, + mMean: cute.Tensor, + mRstd: cute.Tensor, + mdX: cute.Tensor, + mdW: cute.Tensor | None, + mdB: cute.Tensor | None, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ): + tidx, _, _ = cute.arch.thread_idx() + bidx_start, _, _ = cute.arch.block_idx() + gdim, _, _ = cute.arch.grid_dim() + if const_expr(self.cluster_n > 1): + cluster_y = cute.arch.block_idx()[1] + else: + cluster_y = const_expr(0) - This reconstructs cute.Tensor views from raw device pointers + explicit - layouts inside the JIT graph, reusing the tuned LayerNormSM100 schedule. - """ - # Mirror Quack-style divisibility contracts so the compiler can prove - # alignment for vectorized loads/stores (and cp.async when enabled). - divby = ( - int(self.copy_bits_x) // self.dtype.width - if const_expr(self.copy_bits_x is not None) - else (128 // self.dtype.width) - ) - ld_assumed = cute.assume(ld, divby=divby) - # Match `mark_compact_shape_dynamic(mode=0, ...)`: M is dynamic, N is static. 
- layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1)) - layout_n = cute.make_layout((self.N,), stride=(1,)) + shape = mX.shape + M = shape[0] + is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n) + + idX = cute.make_identity_tensor(shape) + + smem = cutlass.utils.SmemAllocator() + smem_layout = cute.make_ordered_layout( + (tiler_mn[0], tiler_mn[1], 2), order=(1, 0, 2) + ) + sX = smem.allocate_tensor(mX.element_type, smem_layout, byte_alignment=16) + sdO = smem.allocate_tensor(mdO.element_type, smem_layout, byte_alignment=16) + reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar( + smem, + tv_layout, + is_persistent=True, + ) + if const_expr(mbar_ptr is not None): + mbar_full_ptr, mbar_empty_ptr = mbar_ptr, mbar_ptr + 2 + else: + mbar_full_ptr, mbar_empty_ptr = None, None + + num_copy_elems_X = ( + tv_layout.shape[1] + if cutlass.const_expr(cute.rank(tv_layout.shape[1]) == 1) + else tv_layout.shape[1][0] + ) + threads_per_row = ( + tv_layout.shape[0] + if cutlass.const_expr(cute.rank(tv_layout.shape[0]) == 1) + else tv_layout.shape[0][0] + ) + copy_atom_load_X = get_copy_atom( + mX.element_type, num_copy_elems_X, is_async=False + ) + thr_layout = cute.make_ordered_layout( + (tiler_mn[0], threads_per_row), order=(1, 0) + ) + val_layout = cute.make_layout((1, num_copy_elems_X)) + thr_copy_X = cute.make_tiled_copy_tv( + copy_atom_load_X, thr_layout, val_layout + ).get_slice(tidx) + copy_fn = partial(_quack_copy, num_copy_elems=num_copy_elems_X) + + gX, gdO, gdX, cX = [ + cute.local_tile(mT, tiler_mn, (None, cluster_y)) + for mT in (mX, mdO, mdX, idX) + ] + gW = cute.local_tile(mW, tiler_mn, (0, cluster_y)) if mW is not None else None + gdW, gdB = [ + cute.local_tile(mT, (1, tiler_mn[1]), (bidx_start, cluster_y)) + if const_expr(mT is not None) + else None + for mT in (mdW, mdB) + ] + + tXgX = thr_copy_X.partition_S(gX) + tXsX = thr_copy_X.partition_D(sX) + tXgdO = thr_copy_X.partition_S(gdO) + tXsdO = thr_copy_X.partition_D(sdO) + tXgdX = thr_copy_X.partition_D(gdX) + tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None] + + tXrX, tXrdO, tXrdX = [ + cute.make_fragment_like(thr[None, None, None, 0]) + for thr in (tXgX, tXgdO, tXgdX) + ] + + tXpX = ( + predicate_k(thr_copy_X.partition_S(cX[None, None, 0]), limit=shape[1]) + if not is_even_N + else None + ) + + tXgdW, tXrdW = None, None + tXgdB, tXrdB = None, None + if const_expr(mdW is not None): + tXgdW = thr_copy_X.partition_S(gdW) + tXrdW = cute.make_fragment_like(tXgdW, Float32) + if const_expr(mdB is not None): + tXgdB = thr_copy_X.partition_S(gdB) + tXrdB = cute.make_fragment_like(tXgdB, Float32) + + num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE + self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=True) + + tXrW = None + if const_expr(mW is not None): + tXgW = thr_copy_X.partition_S(gW) + tXrW = cute.make_fragment_like(tXgW) + if not is_even_N: + tXrW.fill(0.0) + copy_fn(tXgW, tXrW, pred=tXpX) + + row = tXcX[None, None, None, bidx_start][0][0] + if row < M: + tXgX_cur = coord_offset_i64(bidx_start, tXgX, dim=3)[None, None, None, 0] + tXgdO_cur = coord_offset_i64(bidx_start, tXgdO, dim=3)[None, None, None, 0] + copy_fn( + tXgX_cur, + tXsX[None, None, None, 0], + pred=tXpX, + is_async=True, + ) + copy_fn( + tXgdO_cur, + tXsdO[None, None, None, 0], + pred=tXpX, + is_async=True, + ) + elif tiler_mn[0] > 1: + fill_oob(tXsX[None, None, None, 0], None, fill_value=mX.element_type.zero) + fill_oob(tXsdO[None, None, None, 0], None, 
fill_value=mdO.element_type.zero) + cute.arch.cp_async_commit_group() + + if const_expr(self.cluster_n > 1): + cute.arch.cluster_wait() + + if const_expr(mdW is not None): + tXrdW.fill(0.0) + if const_expr(mdB is not None): + tXrdB.fill(0.0) + stage = Int32(0) + producer_phase = Int32(1) + consumer_phase = Int32(0) + for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim): + row = tXcX[None, None, None, bidx][0][0] + if row + gdim * tiler_mn[0] < M: + tXgX_cur = coord_offset_i64(bidx + gdim, tXgX, dim=3)[ + None, None, None, 0 + ] + tXgdO_cur = coord_offset_i64(bidx + gdim, tXgdO, dim=3)[ + None, None, None, 0 + ] + copy_fn( + tXgX_cur, + tXsX[None, None, None, stage ^ 1], + pred=tXpX, + is_async=True, + ) + copy_fn( + tXgdO_cur, + tXsdO[None, None, None, stage ^ 1], + pred=tXpX, + is_async=True, + ) + elif tiler_mn[0] > 1: + fill_oob( + tXsX[None, None, None, stage ^ 1], + None, + fill_value=mX.element_type.zero, + ) + fill_oob( + tXsdO[None, None, None, stage ^ 1], + None, + fill_value=mdO.element_type.zero, + ) + cute.arch.cp_async_commit_group() + rstd_val = cutlass.Float.zero + mean_val = cutlass.Float.zero + if row < M or tiler_mn[0] == 1: + rstd_val = mRstd[row] + mean_val = mMean[row] + cute.arch.cp_async_wait_group(1) + cute.autovec_copy(tXsX[None, None, None, stage], tXrX) + x = tXrX.load().to(cute.Float32) + cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO) + dout = tXrdO.load().to(cute.Float32) + x_hat = (x - mean_val) * rstd_val + wdy = dout + if const_expr(mW is not None): + wdy *= tXrW.load().to(Float32) + if const_expr(self.cluster_n > 1): + cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase) + mean_xhat_wdy = ( + row_reduce_add( + x_hat * wdy, + threads_per_row, + reduction_buffer[None, None, stage], + (mbar_full_ptr + stage if const_expr(self.cluster_n > 1) else None), + phase=consumer_phase, + init_val=0.0, + ) + / shape[1] + ) + mean_wdy = ( + row_reduce_add( + wdy, + threads_per_row, + reduction_buffer[None, None, stage ^ 1], + None, + init_val=0.0, + ) + / shape[1] + ) + + if const_expr(self.cluster_n > 1): + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + cute.arch.sync_warp() + lane_idx = cute.arch.lane_idx() + if lane_idx < self.cluster_n: + cute.arch.mbarrier_arrive( + mbar_empty_ptr + stage, + peer_cta_rank_in_cluster=lane_idx, + ) + + if const_expr(self.reload_wdy == "smem"): + cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO) + dout = tXrdO.load().to(cute.Float32) + wdy = dout + if const_expr(mW is not None): + wdy *= tXrW.load().to(Float32) + + dx = (wdy - mean_wdy - x_hat * mean_xhat_wdy) * rstd_val + tXrdX.store(dx.to(tXrdX.element_type)) + if row < M or tiler_mn[0] == 1: + tXgdX_cur = coord_offset_i64(bidx, tXgdX, dim=3)[None, None, None, 0] + copy_fn(tXrdX, tXgdX_cur, pred=tXpX) + if const_expr(mdW is not None): + tXrdW.store(tXrdW.load() + dout * x_hat) + if const_expr(mdB is not None): + tXrdB.store(tXrdB.load() + dout) + + stage ^= 1 + if stage == 0: + consumer_phase ^= 1 + producer_phase ^= 1 + + if const_expr(tiler_mn[0] > 1): + if const_expr(mdW is not None): + sdW = cute.make_tensor( + cute.recast_ptr(sX.iterator, dtype=cute.Float32), + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + ) + tXsdW = thr_copy_X.partition_D(sdW) + cute.arch.barrier() + row0 = tXcX[None, None, None, 0][0][0] + if row0 > 0: + cute.autovec_copy(tXrdW, tXsdW) + cute.arch.barrier() + if row0 == 0: + for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])): + 
tXrdW_other = cute.make_fragment_like(tXrdW) + tXsdW_other = cute.make_tensor( + tXsdW.iterator + i * sdW.stride[0], + tXsdW.layout, + ) + cute.autovec_copy(tXsdW_other, tXrdW_other) + tXrdW.store(tXrdW.load() + tXrdW_other.load()) + if const_expr(self.atomic_dw): + atomic_add_tensor_f32(tXrdW, tXgdW, pred=tXpX) + else: + copy_fn(tXrdW, tXgdW, pred=tXpX) + cute.arch.barrier() + if const_expr(mdB is not None): + sdB = cute.make_tensor( + cute.recast_ptr(sX.iterator, dtype=cute.Float32), + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + ) + tXsdB = thr_copy_X.partition_D(sdB) + cute.arch.barrier() + row0 = tXcX[None, None, None, 0][0][0] + if row0 > 0: + cute.autovec_copy(tXrdB, tXsdB) + cute.arch.barrier() + if row0 == 0: + for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])): + tXrdB_other = cute.make_fragment_like(tXrdB) + tXsdB_other = cute.make_tensor( + tXsdB.iterator + i * sdB.stride[0], + tXsdB.layout, + ) + cute.autovec_copy(tXsdB_other, tXrdB_other) + tXrdB.store(tXrdB.load() + tXrdB_other.load()) + copy_fn(tXrdB, tXgdB, pred=tXpX) + else: + if const_expr(mdW is not None): + if const_expr(self.atomic_dw): + atomic_add_tensor_f32(tXrdW, tXgdW, pred=tXpX) + else: + copy_fn(tXrdW, tXgdW, pred=tXpX) + if const_expr(mdB is not None): + copy_fn(tXrdB, tXgdB, pred=tXpX) + + if const_expr(self.cluster_n > 1): + stage ^= 1 + if stage == 0: + producer_phase ^= 1 + cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase) + + if _KERNEL_ACCEPTS_LAYOUT_ARGS: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mW: cute.Tensor | None, + mdO: cute.Tensor, + mMean: cute.Tensor, + mRstd: cute.Tensor, + mdX: cute.Tensor, + mdW: cute.Tensor | None, + mdB: cute.Tensor | None, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ): + self._kernel_impl( + mX, + mW, + mdO, + mMean, + mRstd, + mdX, + mdW, + mdB, + tv_layout, + tiler_mn, + ) + else: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mW: cute.Tensor | None, + mdO: cute.Tensor, + mMean: cute.Tensor, + mRstd: cute.Tensor, + mdX: cute.Tensor, + mdW: cute.Tensor | None, + mdB: cute.Tensor | None, + ): + largest_dtype_width = const_expr( + max( + mX.element_type.width, + mW.element_type.width if mW is not None else 0, + mdO.element_type.width, + mdX.element_type.width, + ) + ) + tiler_mn, tv_layout = self._get_tv_layout( + num_copy_bits=128 // largest_dtype_width * mX.element_type.width + ) + self._kernel_impl( + mX, + mW, + mdO, + mMean, + mRstd, + mdX, + mdW, + mdB, + tv_layout, + tiler_mn, + ) + + @cute.jit + def launch_from_ptrs( + self, + ptr_x: cute.Pointer, + ptr_w: cute.Pointer, + ptr_dout: cute.Pointer, + ptr_rstd: cute.Pointer, + ptr_mean: cute.Pointer, + ptr_dx: cute.Pointer, + ptr_dw_acc: cute.Pointer, + ptr_db_acc: Optional[cute.Pointer], + M: Int32, + ld: Int32, + sm_count: Int32, + stream: cuda.CUstream, + ) -> None: + ld_assumed = cute.assume(ld, divby=256 // self.dtype.width) + layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1)) + layout_n = cute.make_layout((self.N,), stride=(1,)) + layout_m = cute.make_layout((M,), stride=(1,)) + layout_dw = cute.make_layout( + (sm_count, self.N), stride=((0 if self.atomic_dw else self.N), 1) + ) + layout_db = cute.make_layout((sm_count, self.N), stride=(self.N, 1)) + + mX = cute.make_tensor(ptr_x, layout_mn) + mW = cute.make_tensor(ptr_w, layout_n) + mdO = cute.make_tensor(ptr_dout, layout_mn) + mRstd = cute.make_tensor(ptr_rstd, layout_m) + mMean = cute.make_tensor(ptr_mean, layout_m) + mdX = cute.make_tensor(ptr_dx, layout_mn) + mdW = 
cute.make_tensor(ptr_dw_acc, layout_dw) + mdB = ( + cute.make_tensor(ptr_db_acc, layout_db) + if const_expr(ptr_db_acc is not None) + else None + ) + + self.__call__( + mX, + mW, + mdO, + mMean, + mRstd, + mdX, + mdW, + mdB, + sm_count, + stream, + ) + + +class _PtrLayernormBwdFastLaunch: + def __init__( + self, + *, + compiled: object, + executor: object, + capi_func: object, + ptr_x: object, + ptr_w: object, + ptr_dout: object, + ptr_rstd: object, + ptr_mean: object, + ptr_dx: object, + ptr_dw_partial: object, + ptr_db_partial: Optional[object], + arg_m: StableI32Arg, + arg_ld: StableI32Arg, + arg_sm_count: StableI32Arg, + stream: cuda.CUstream, + assumed_align_x: int, + assumed_align_w: int, + assumed_align_dw: int, + weight_dtype: type[cutlass.Numeric], + packed_args: object, + keepalive: tuple[object, ...], + ): + self._compiled = compiled + self._executor = executor + self._capi_func = capi_func + self._ptr_x = ptr_x + self._ptr_w = ptr_w + self._ptr_dout = ptr_dout + self._ptr_rstd = ptr_rstd + self._ptr_mean = ptr_mean + self._ptr_dx = ptr_dx + self._ptr_dw_partial = ptr_dw_partial + self._ptr_db_partial = ptr_db_partial + self._arg_m = arg_m + self._arg_ld = arg_ld + self._arg_sm_count = arg_sm_count + self._stream = stream + self._assumed_align_x = int(assumed_align_x) + self._assumed_align_w = int(assumed_align_w) + self._assumed_align_dw = int(assumed_align_dw) + self._weight_dtype = weight_dtype + self._packed_args = packed_args + self._keepalive = keepalive + + self._use_fast_launch = True + self._cuda_result = getattr(executor, "cuda_result", None) + + self._last_x_ptr = -1 + self._last_w_ptr = -1 + self._last_dout_ptr = -1 + self._last_rstd_ptr = -1 + self._last_mean_ptr = -1 + self._last_dx_ptr = -1 + self._last_dw_ptr = -1 + self._last_db_ptr = -1 + self._last_m = -1 + self._last_ld = -1 + self._last_sm_count = -1 + + def launch( + self, + *, + x: Tensor, + weight: Tensor, + dout: Tensor, + rstd: Tensor, + mean: Tensor, + dx: Tensor, + dw_partial: Tensor, + db_partial: Optional[Tensor], + M: int, + ld: int, + sm_count: int, + ) -> None: + if not fast_launch_enabled() or not self._use_fast_launch: + self._fallback_launch( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + mean=mean, + dx=dx, + dw_partial=dw_partial, + db_partial=db_partial, + M=M, + ld=ld, + sm_count=sm_count, + ) + return + + def _update_ptr(last_name: str, ptr_obj: object, value: int) -> bool: + last_attr = ( + last_name if last_name.startswith("_last_") else f"_last_{last_name}" + ) + if value == getattr(self, last_attr): + return True + try: + set_runtime_ptr(ptr_obj, value) + setattr(self, last_attr, value) + return True + except AttributeError: + self._disable_fast_launch() + return False + + if not _update_ptr("x_ptr", self._ptr_x, x.data_ptr()): + self._fallback_launch( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + mean=mean, + dx=dx, + dw_partial=dw_partial, + db_partial=db_partial, + M=M, + ld=ld, + sm_count=sm_count, + ) + return + if not _update_ptr("w_ptr", self._ptr_w, weight.data_ptr()): + self._fallback_launch( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + mean=mean, + dx=dx, + dw_partial=dw_partial, + db_partial=db_partial, + M=M, + ld=ld, + sm_count=sm_count, + ) + return + if not _update_ptr("dout_ptr", self._ptr_dout, dout.data_ptr()): + self._fallback_launch( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + mean=mean, + dx=dx, + dw_partial=dw_partial, + db_partial=db_partial, + M=M, + ld=ld, + sm_count=sm_count, + ) + return + if not _update_ptr("rstd_ptr", 
self._ptr_rstd, rstd.data_ptr()): + self._fallback_launch( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + mean=mean, + dx=dx, + dw_partial=dw_partial, + db_partial=db_partial, + M=M, + ld=ld, + sm_count=sm_count, + ) + return + if not _update_ptr("mean_ptr", self._ptr_mean, mean.data_ptr()): + self._fallback_launch( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + mean=mean, + dx=dx, + dw_partial=dw_partial, + db_partial=db_partial, + M=M, + ld=ld, + sm_count=sm_count, + ) + return + if not _update_ptr("dx_ptr", self._ptr_dx, dx.data_ptr()): + self._fallback_launch( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + mean=mean, + dx=dx, + dw_partial=dw_partial, + db_partial=db_partial, + M=M, + ld=ld, + sm_count=sm_count, + ) + return + if not _update_ptr("dw_ptr", self._ptr_dw_partial, dw_partial.data_ptr()): + self._fallback_launch( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + mean=mean, + dx=dx, + dw_partial=dw_partial, + db_partial=db_partial, + M=M, + ld=ld, + sm_count=sm_count, + ) + return + + if self._ptr_db_partial is not None and db_partial is not None: + if not _update_ptr("db_ptr", self._ptr_db_partial, db_partial.data_ptr()): + self._fallback_launch( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + mean=mean, + dx=dx, + dw_partial=dw_partial, + db_partial=db_partial, + M=M, + ld=ld, + sm_count=sm_count, + ) + return + + if M != self._last_m: + self._arg_m.set(M) + self._last_m = M + if ld != self._last_ld: + self._arg_ld.set(ld) + self._last_ld = ld + if sm_count != self._last_sm_count: + self._arg_sm_count.set(sm_count) + self._last_sm_count = sm_count + + if self._cuda_result is not None: + self._cuda_result.value = 0 + ret = self._capi_func(self._packed_args) # type: ignore[misc] + if ret != 0: + raise RuntimeError(f"CuTeDSL capi_func returned non-zero: {ret}") + if self._cuda_result is not None: + err = int(self._cuda_result.value) + if err != 0: + raise RuntimeError(f"CuTeDSL kernel launch failed (cuda_result={err})") + + def _disable_fast_launch(self) -> None: + self._use_fast_launch = False + disable_fast_launch() + + def _fallback_launch( + self, + *, + x: Tensor, + weight: Tensor, + dout: Tensor, + rstd: Tensor, + mean: Tensor, + dx: Tensor, + dw_partial: Tensor, + db_partial: Optional[Tensor], + M: int, + ld: int, + sm_count: int, + ) -> None: + dtype = TORCH2CUTE_DTYPE[x.dtype] + ptr_x = rt.make_ptr( + dtype, + x.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align_x, + ) + ptr_w = rt.make_ptr( + self._weight_dtype, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align_w, + ) + ptr_dout = rt.make_ptr( + dtype, + dout.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align_x, + ) + ptr_rstd = rt.make_ptr( + TORCH2CUTE_DTYPE[rstd.dtype], + rstd.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + ptr_mean = rt.make_ptr( + TORCH2CUTE_DTYPE[mean.dtype], + mean.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + ptr_dx = rt.make_ptr( + dtype, + dx.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align_x, + ) + ptr_dw_partial = rt.make_ptr( + TORCH2CUTE_DTYPE[dw_partial.dtype], + dw_partial.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align_dw, + ) + ptr_db_partial = ( + rt.make_ptr( + TORCH2CUTE_DTYPE[db_partial.dtype], + db_partial.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align_dw, + ) + if db_partial is not None + else None + ) + 
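+        # Slow path: rebuild the pointer arguments from the live tensors and call
+        # the compiled entrypoint directly instead of patching the cached packed
+        # args used by the fast launcher.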
self._compiled( + ptr_x, + ptr_w, + ptr_dout, + ptr_rstd, + ptr_mean, + ptr_dx, + ptr_dw_partial, + ptr_db_partial, + Int32(M), + Int32(ld), + Int32(sm_count), + self._stream, + ) + + +def _get_fast_ptr_layernorm_bwd_launcher( + *, + compiled: object, + dtype: type[cutlass.Numeric], + weight_dtype: type[cutlass.Numeric], + N: int, + device_index: int, + stream_handle: int, + has_db_partial: bool, + assumed_align_x: int, + assumed_align_w: int, + assumed_align_dw: int, +) -> Optional[_PtrLayernormBwdFastLaunch]: + if not fast_launch_enabled(): + return None + + key = ( + "layernorm_bwd_ptr_fast", + id(compiled), + int(N), + dtype, + weight_dtype, + bool(has_db_partial), + int(device_index), + int(stream_handle), + int(assumed_align_x), + int(assumed_align_w), + int(assumed_align_dw), + ) + cache = _tls_fast_launch_cache() + cached = cache.get(key) + if cached is not None: + return cached # type: ignore[return-value] + + ptr_x = rt.make_ptr( + dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=int(assumed_align_x) + ) + ptr_w = rt.make_ptr( + weight_dtype, + 0, + mem_space=rt.AddressSpace.gmem, + assumed_align=int(assumed_align_w), + ) + ptr_dout = rt.make_ptr( + dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=int(assumed_align_x) + ) + ptr_rstd = rt.make_ptr( + cutlass.Float32, + 0, + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + ptr_mean = rt.make_ptr( + cutlass.Float32, + 0, + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + ptr_dx = rt.make_ptr( + dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=int(assumed_align_x) + ) + ptr_dw_partial = rt.make_ptr( + cutlass.Float32, + 0, + mem_space=rt.AddressSpace.gmem, + assumed_align=int(assumed_align_dw), + ) + ptr_db_partial = ( + rt.make_ptr( + cutlass.Float32, + 0, + mem_space=rt.AddressSpace.gmem, + assumed_align=int(assumed_align_dw), + ) + if has_db_partial + else None + ) + + arg_m = StableI32Arg(0) + arg_ld = StableI32Arg(N) + arg_sm_count = StableI32Arg(0) + stream = cuda.CUstream(int(stream_handle)) + executor = compiled.to(device_index) # type: ignore[attr-defined] + + try: + exe_args, adapted_args = executor.generate_execution_args( + ptr_x, + ptr_w, + ptr_dout, + ptr_rstd, + ptr_mean, + ptr_dx, + ptr_dw_partial, + ptr_db_partial, + arg_m, + arg_ld, + arg_sm_count, + stream, + ) + packed_args = executor._get_invoke_packed_args(list(exe_args)) # type: ignore[attr-defined] + capi_func = compiled.capi_func # type: ignore[attr-defined] + except AttributeError: + disable_fast_launch() + return None + + keepalive: tuple[object, ...] = ( + executor, + ptr_x, + ptr_w, + ptr_dout, + ptr_rstd, + ptr_mean, + ptr_dx, + ptr_dw_partial, + ptr_db_partial, + arg_m, + arg_ld, + arg_sm_count, + stream, + *adapted_args, + ) + launcher = _PtrLayernormBwdFastLaunch( + compiled=compiled, + executor=executor, + capi_func=capi_func, + ptr_x=ptr_x, + ptr_w=ptr_w, + ptr_dout=ptr_dout, + ptr_rstd=ptr_rstd, + ptr_mean=ptr_mean, + ptr_dx=ptr_dx, + ptr_dw_partial=ptr_dw_partial, + ptr_db_partial=ptr_db_partial, + arg_m=arg_m, + arg_ld=arg_ld, + arg_sm_count=arg_sm_count, + stream=stream, + assumed_align_x=int(assumed_align_x), + assumed_align_w=int(assumed_align_w), + assumed_align_dw=int(assumed_align_dw), + weight_dtype=weight_dtype, + packed_args=packed_args, + keepalive=keepalive, + ) + cache[key] = launcher + return launcher + + +def _convert_row_major(t: Tensor) -> cute.Tensor: + """ + Convert a 2D row-major torch.Tensor to a CuTeDSL tensor with a compact, + dynamic layout on the leading dimension. 
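+    The LayerNorm backward host helpers use it for the 2D x / dout / dx
+    activation tensors.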
+ """ + return from_dlpack(t.detach(), assumed_align=16).mark_compact_shape_dynamic( + mode=0, + stride_order=(0, 1), + ) + + +class LayerNormSM100(_ReductionBase): + """ + SM100 LayerNorm forward kernel. + + This mirrors `quack.layernorm.LayerNorm`'s schedule: + - Stage=2 pipeline: first pass computes mean, second pass computes + variance / rstd and normalization. + - Threads-per-row and cluster_n policy follow Quack's LayerNorm + heuristics to keep tensor-core friendly tiles across N. + - Optional `reload_from` hint enables reloading X from SMEM for large-N + shapes to shorten register lifetimes. + + Differences vs Quack: + - Bias is optional and supported directly in the kernel. + - Dtype mapping and reduction helpers come from `lite_quack`. + """ + + def __init__( + self, + dtype: type[cutlass.Numeric], + N: int, + *, + copy_bits_x: Optional[int] = None, + direct_gmem: bool = False, + ): + super().__init__(dtype, N, stage=2) # 2 stages for mean and var + # Default reload policy mirrors Quack: use SMEM reload only for + # very large hidden sizes. We keep this conservative for LayerNorm + # and tune primarily via threads-per-block / cluster_n. + self.reload_from: Optional[str] = None if N <= 16384 else "smem" + # SM100 tuning: for DSv3 hidden sizes where we fuse mean+var stats, + # delay loading fp32 weights/bias until after the reductions to lower + # register pressure. + self.delay_w_load: bool = bool(N in (4096, 6144, 7168, 8192)) + self.copy_bits_x: Optional[int] = ( + int(copy_bits_x) if copy_bits_x is not None else None + ) + self.direct_gmem: bool = bool(direct_gmem) + + def _get_num_threads(self) -> int: + nt = getattr(self, "_nt_override", None) + if nt is not None: + return int(nt) + return super()._get_num_threads() + + def _calculate_threads_per_row(self) -> int: + tpr = getattr(self, "_tpr_override", None) + if tpr is not None: + return int(tpr) + # Match Quack's LayerNorm threads-per-row buckets. + N = self.N + if N in (4096, 6144): + return 128 + return ( + 8 + if N <= 64 + else ( + 16 + if N <= 128 + else ( + 32 + if N <= 3072 + else (64 if N <= 6144 else (128 if N <= 16384 else 256)) + ) + ) + ) + + def _set_cluster_n(self) -> None: + # Cluster_n policy mirrors quack.layernorm.LayerNorm._set_cluster_n. + N = self.N + if const_expr(self.dtype.width == 16): + cluster_n = ( + 1 + if N <= 16 * 1024 + else ( + 2 + if N <= 32 * 1024 + else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)) + ) + ) + else: + cluster_n = ( + 1 + if N <= 32 * 1024 + else ( + 2 + if N <= 64 * 1024 + else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16)) + ) + ) + self.cluster_n = cluster_n + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mW: cute.Tensor, + mB: Optional[cute.Tensor], + mO: cute.Tensor, + mRstd: Optional[cute.Tensor], + mMean: Optional[cute.Tensor], + stream: cuda.CUstream, + eps: Float32 = 1e-6, + ): + assert mX.element_type == self.dtype + assert mO.element_type == self.dtype + + # Tiling and cluster policy (mirrors Quack LayerNorm). + self._set_cluster_n() + largest_dtype_width = const_expr( + max( + t.element_type.width + for t in (mX, mW, mB, mO, mRstd, mMean) + if t is not None + ) + ) + # Match Quack's unified RMSNorm/LayerNorm kernel: pick vecsize based on + # the widest dtype participating in the op (e.g. fp32 weights => fp16 + # X uses 64b vectorization). 
+ vecsize = math.gcd(self.N, 128 // largest_dtype_width) + default_copy_bits_x = vecsize * self.dtype.width + num_copy_bits_x = ( + int(self.copy_bits_x) + if self.copy_bits_x is not None + else default_copy_bits_x + ) + tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits_x) + num_threads = ( + cute.size(tv_layout, mode=[0]) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._get_num_threads() + ) + num_warps = num_threads // cute.arch.WARP_SIZE + + # Expand weight / bias to match tiler_mn[0] rows per CTA. + mW = cute.make_tensor( + mW.iterator, + cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))), + ) + if const_expr(mB is not None): + mB = cute.make_tensor( + mB.iterator, + cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))), + ) + if const_expr(mRstd is not None): + mRstd = cute.make_tensor( + mRstd.iterator, + cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,))), + ) + if const_expr(mMean is not None): + mMean = cute.make_tensor( + mMean.iterator, + cute.append(mMean.layout, cute.make_layout((self.N,), stride=(0,))), + ) + + kernel = ( + self.kernel( + mX, + mW, + mB, + mO, + mRstd, + mMean, + eps, + tv_layout, + tiler_mn, + const_expr(self.reload_from), + const_expr(self.delay_w_load), + ) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self.kernel( + mX, + mW, + mB, + mO, + mRstd, + mMean, + eps, + ) + ) + kernel.launch( + grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1], + block=[num_threads, 1, 1], + cluster=[ + 1, + self.cluster_n, + 1, + ] + if const_expr(self.cluster_n > 1) + else None, + smem=self._smem_size_in_bytes(tiler_mn, num_warps), + stream=stream, + ) + + @cute.jit + def launch_from_ptrs( + self, + ptr_x: cute.Pointer, + ptr_w: cute.Pointer, + ptr_b: Optional[cute.Pointer], + ptr_out: cute.Pointer, + ptr_rstd: Optional[cute.Pointer], + ptr_mean: Optional[cute.Pointer], + M: Int32, + ld: Int32, + stream: cuda.CUstream, + eps: Float32 = 1e-6, + ) -> None: + """Pointer-based entrypoint that bypasses DLPack conversions. + + This reconstructs cute.Tensor views from raw device pointers + explicit + layouts inside the JIT graph, reusing the tuned LayerNormSM100 schedule. + """ + # Mirror Quack-style divisibility contracts so the compiler can prove + # alignment for vectorized loads/stores (and cp.async when enabled). + divby = ( + int(self.copy_bits_x) // self.dtype.width + if const_expr(self.copy_bits_x is not None) + else (128 // self.dtype.width) + ) + ld_assumed = cute.assume(ld, divby=divby) + # Match `mark_compact_shape_dynamic(mode=0, ...)`: M is dynamic, N is static. + layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1)) + layout_n = cute.make_layout((self.N,), stride=(1,)) layout_m = cute.make_layout((M,), stride=(1,)) mX = cute.make_tensor(ptr_x, layout_mn) @@ -1186,7 +2454,9 @@ def layernorm( ) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - key = (N, dtype, mB is not None, mRstd is not None, mMean is not None) + dtype_w = TORCH2CUTE_DTYPE[weight.dtype] + dtype_b = TORCH2CUTE_DTYPE[bias.dtype] if bias is not None else None + key = (N, dtype, dtype_w, dtype_b, mRstd is not None, mMean is not None) compiled = _COMPILE_CACHE.get(key) if compiled is None: op = LayerNormSM100(dtype, N) @@ -1226,8 +2496,9 @@ def layernorm( def _can_use_ptr_path(x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> bool: """Return True if we can safely use the pointer-based fast path. 
- This is intentionally conservative: we target the common inference-like - layout (2D row-major with stride(1)==1) and Quack-style fp32 weights. + This path supports both Quack-style fp32 weights/bias and same-dtype + weights/bias for bf16/fp16 activations, as long as the layout stays in the + common row-major form. """ if not x.is_cuda or x.dim() != 2: return False @@ -1235,26 +2506,29 @@ def _can_use_ptr_path(x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> bool return False if not weight.is_cuda or weight.dim() != 1: return False - if weight.dtype != torch.float32: - return False + if weight.dtype != x.dtype: + if weight.dtype != torch.float32: + return False + if x.dtype not in (torch.float16, torch.bfloat16): + return False if not weight.is_contiguous(): return False if bias is not None: if not bias.is_cuda or bias.dim() != 1: return False - if bias.dtype != torch.float32: + if bias.dtype != weight.dtype: return False if not bias.is_contiguous(): return False - # Require 16B alignment for 128-bit vector copies (matches Quack's assumed_align=16). + # Require 16B alignment for vectorized loads/stores. if (x.data_ptr() % 16) != 0: return False if (weight.data_ptr() % 16) != 0: return False if bias is not None and (bias.data_ptr() % 16) != 0: return False - # The kernel uses 128-bit vectorized loads; require the leading dimension - # to preserve 16B alignment for every row start. + # The kernel uses vectorized loads; require the leading dimension to + # preserve 16B alignment for every row start. dtype_x = TORCH2CUTE_DTYPE[x.dtype] divby = 128 // dtype_x.width if (x.stride(0) % divby) != 0: @@ -1292,6 +2566,8 @@ def _layernorm_forward_ptr_into( stream = cuda.CUstream(stream_handle) dtype_x = TORCH2CUTE_DTYPE[x.dtype] + dtype_w = TORCH2CUTE_DTYPE[weight.dtype] + dtype_b = TORCH2CUTE_DTYPE[bias.dtype] if bias is not None else None # Keep the pointer path aligned with Quack's LayerNorm schedule: # - <=128b vectorization (cp.async-compatible) # - shared-memory staging for X (gmem->smem->rmem) to amortize global latency @@ -1327,6 +2603,8 @@ def _layernorm_forward_ptr_into( "ptr", int(N), dtype_x, + dtype_w, + dtype_b, bias is not None, rstd is not None, mean is not None, @@ -1362,14 +2640,14 @@ def _layernorm_forward_ptr_into( assumed_align=assumed_align_xo, ) ptr_w = rt.make_ptr( - cutlass.Float32, + dtype_w, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16, ) ptr_b = ( rt.make_ptr( - cutlass.Float32, + dtype_b, bias.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16, @@ -1417,6 +2695,8 @@ def _layernorm_forward_ptr_into( compiled=compiled, N=int(N), dtype_x=dtype_x, + dtype_w=dtype_w, + dtype_b=dtype_b, has_bias=bias is not None, has_rstd=rstd is not None, has_mean=mean is not None, @@ -1453,14 +2733,14 @@ def _layernorm_forward_ptr_into( assumed_align=assumed_align_xo, ) ptr_w = rt.make_ptr( - cutlass.Float32, + dtype_w, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16, ) ptr_b = ( rt.make_ptr( - cutlass.Float32, + dtype_b, bias.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16, @@ -1740,136 +3020,729 @@ def _layernorm_backward_param( ) -def _layernorm_backward_dx_sm100( +def _layernorm_backward_dx_sm100( + dout_2d: Tensor, + x_2d: Tensor, + weight: Tensor, + rstd_1d: Tensor, + mean_1d: Tensor, + dx_2d: Tensor, +) -> None: + """ + Host-side helper to run the dx-only LayerNorm backward kernel. 
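+    Compiled kernels are cached per (N, x dtype, weight dtype, dout dtype) key.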
+ """ + M, N = x_2d.shape + assert dout_2d.shape == (M, N) + assert rstd_1d.numel() == M + assert mean_1d.numel() == M + + dtype = TORCH2CUTE_DTYPE[x_2d.dtype] + + mX = _convert_row_major(x_2d) + mdO = _convert_row_major(dout_2d) + mdX = _convert_row_major(dx_2d) + + mW = convert_from_dlpack_cute( + weight.detach(), + leading_dim=0, + alignment=16, + divisibility=128 // cutlass.Float32.width, + ) + mRstd = from_dlpack(rstd_1d.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) + mMean = from_dlpack(mean_1d.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + dtype_w = TORCH2CUTE_DTYPE[weight.dtype] + dtype_dout = TORCH2CUTE_DTYPE[dout_2d.dtype] + key = (N, dtype, dtype_w, dtype_dout) + compiled = _BWD_DX_COMPILE_CACHE.get(key) + if compiled is None: + compiled = cute.compile( + _layernorm_backward_dx, + mX, + mW, + mdO, + mRstd, + mMean, + mdX, + stream, + ) + _BWD_DX_COMPILE_CACHE[key] = compiled + + compiled( + mX, + mW, + mdO, + mRstd, + mMean, + mdX, + stream, + ) + + +def _layernorm_backward_params_sm100( + dout_2d: Tensor, + x_2d: Tensor, + rstd_1d: Tensor, + mean_1d: Tensor, + dw_partial: Optional[Tensor], + db_partial: Optional[Tensor], + sm_count: int, +) -> None: + """ + Host-side helper to run the parameter-gradient kernel that populates + dw_partial / db_partial of shape (sm_count, N). + """ + M, N = x_2d.shape + assert dout_2d.shape == (M, N) + assert rstd_1d.numel() == M + assert mean_1d.numel() == M + if dw_partial is None and db_partial is None: + return + + dtype = TORCH2CUTE_DTYPE[x_2d.dtype] + + mX = _convert_row_major(x_2d) + mdO = _convert_row_major(dout_2d) + mRstd = from_dlpack(rstd_1d.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) + mMean = from_dlpack(mean_1d.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) + + mdW_partial = ( + from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0) + if dw_partial is not None + else None + ) + mdB_partial = ( + from_dlpack(db_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0) + if db_partial is not None + else None + ) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + has_bias = db_partial is not None + dtype_dout = TORCH2CUTE_DTYPE[dout_2d.dtype] + key = (N, dtype, dtype_dout, has_bias) + compiled = _BWD_PARAM_COMPILE_CACHE.get(key) + if compiled is None: + compiled = cute.compile( + _layernorm_backward_param, + mX, + mdO, + mRstd, + mMean, + mdW_partial, + mdB_partial, + Int32(sm_count), + stream, + ) + _BWD_PARAM_COMPILE_CACHE[key] = compiled + + compiled( + mX, + mdO, + mRstd, + mMean, + mdW_partial, + mdB_partial, + Int32(sm_count), + stream, + ) + + +def _can_use_ptr_path_bwd( + x: Tensor, + weight: Tensor, + dout: Tensor, + rstd: Tensor, + mean: Tensor, +) -> bool: + """Return whether tensors satisfy the optimized pointer-path contract. + + The pointer kernel assumes row-major 2D activations/gradients, contiguous + fp32 per-row statistics, 16B activation alignment, and enough hidden-dim + divisibility for vectorized/cp.async copies. Keep this predicate strict: + callers that fail it are still handled by the local fallback kernels. 
+ """ + if not x.is_cuda or not dout.is_cuda or x.dim() != 2 or dout.shape != x.shape: + return False + if dout.dtype != x.dtype: + return False + if x.stride(1) != 1 or dout.stride(1) != 1: + return False + if dout.stride(0) != x.stride(0): + return False + if not weight.is_cuda or weight.dim() != 1 or weight.shape[0] != x.shape[1]: + return False + if weight.dtype != x.dtype: + if weight.dtype != torch.float32: + return False + if x.dtype not in (torch.float16, torch.bfloat16): + return False + if not weight.is_contiguous(): + return False + if (x.data_ptr() % 16) != 0 or (dout.data_ptr() % 16) != 0: + return False + assumed_align_w = 32 if weight.dtype == torch.float32 else 16 + if (weight.data_ptr() % assumed_align_w) != 0: + return False + if (rstd.data_ptr() % 4) != 0 or (mean.data_ptr() % 4) != 0: + return False + if (not rstd.is_cuda) or (not mean.is_cuda): + return False + if rstd.dtype != torch.float32 or mean.dtype != torch.float32: + return False + if (not rstd.is_contiguous()) or (not mean.is_contiguous()): + return False + if rstd.numel() != x.shape[0] or mean.numel() != x.shape[0]: + return False + dtype_x = TORCH2CUTE_DTYPE[x.dtype] + divby = 256 // dtype_x.width + if (x.stride(0) % divby) != 0: + return False + # The CuTe tiled-copy path is vectorized along N; requiring a multiple of 8 + # for fp16/bf16 covers the DSv3/DSv4 hidden sizes while avoiding a slower + # predicated tail specialization in the hot pointer path. + if (x.shape[1] % 8) != 0: + return False + return True + + +def _get_layernorm_bwd_sm_count( + N: int, + device: torch.device, + *, + M: Optional[int] = None, + dtype: Optional[torch.dtype] = None, +) -> int: + sm_count_multiple = ( + 16 + if N <= 256 + else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1))) + ) + props = torch.cuda.get_device_properties(device) + sm_count = props.multi_processor_count + sm_count = ( + sm_count * sm_count_multiple + if N <= 8192 + else sm_count // 2 + if N <= 16384 + else sm_count * 2 + ) + # The self-contained SM103 LayerNorm backward kernel is register/smem limited + # to about two resident CTAs per SM for N=8192. Large-M shapes need both + # slots populated, while small-M shapes are more launch-tail sensitive. + if ( + props.major == 10 + and props.minor == 3 + and props.multi_processor_count == 152 + and dtype in (torch.float16, torch.bfloat16) + and N == 8192 + ): + sm_count = ( + props.multi_processor_count + if M is not None and M <= 4096 + else props.multi_processor_count * 2 + ) + return int(sm_count) + + +def _get_layernorm_bwd_tuning( + N: int, + dtype_x: type[cutlass.Numeric], +) -> tuple[Optional[int], Optional[int], Optional[int]]: + tpr_override: Optional[int] = None + nt_override: Optional[int] = None + cluster_n_override: Optional[int] = None + + if N == 4096 and dtype_x.width == 16: + # On GB300, widening the persistent grid is already enough for 4k hidden + # sizes; a 128-thread CTA / 128 threads-per-row schedule keeps the fused + # kernel exact for the larger DSv3 LayerNorm backward cases. + tpr_override = 128 + nt_override = 128 + elif N == 8192 and dtype_x.width == 16: + # One row per CTA, with a full 256-thread row reduction to reduce per-thread + # register work on the wide 8k DSv3 rows. 
+ tpr_override = 256 + nt_override = 256 + cluster_n_override = 1 + + return tpr_override, nt_override, cluster_n_override + + +def _should_use_layernorm_bwd_ptr(x: Tensor, weight: Tensor) -> bool: + """Return True only for shapes where the pointer bwd path is a stable win.""" + N = int(x.shape[-1]) + M = int(x.numel() // N) + if x.dtype not in (torch.float16, torch.bfloat16): + return False + + props = torch.cuda.get_device_properties(x.device) + is_gb300 = ( + props.major == 10 and props.minor == 3 and props.multi_processor_count == 152 + ) + + # DSv3/DSv4 LayerNorm backward target on GB300. Route the wide bf16 rows + # through Oink's self-contained fused pointer path only when the stricter + # layout/alignment checks pass; all other shapes use local fallback kernels. + if is_gb300 and N in (6144, 7168, 8192): + return weight.dtype == x.dtype + if N == 4096 and M >= 8192 and is_gb300: + return weight.dtype in (x.dtype, torch.float32) + return False + + +def _should_use_layernorm_bwd_atomic_dw_ptr(x: Tensor, weight: Tensor) -> bool: + """Use atomic dW only where it wins over partial-buffer reduction on GB300.""" + if x.dtype not in (torch.float16, torch.bfloat16) or weight.dtype != x.dtype: + return False + props = torch.cuda.get_device_properties(x.device) + is_gb300 = ( + props.major == 10 and props.minor == 3 and props.multi_processor_count == 152 + ) + if not is_gb300: + return False + N = int(x.shape[-1]) + M = int(x.numel() // N) + if N == 8192: + return M >= 16384 + if N == 6144: + return M >= 65536 + return False + + +def _layernorm_backward_atomic_ptr( + *, + dout_2d: Tensor, + x_2d: Tensor, + weight: Tensor, + rstd_1d: Tensor, + mean_1d: Tensor, + dx_2d: Tensor, + dw_acc: Tensor, + db_acc: Optional[Tensor], + sm_count: int, +) -> None: + """Run the pointer kernel with direct atomic dW accumulation. + + dB remains a regular (sm_count, N) partial buffer when requested; only dW + uses the direct atomic accumulation variant. 
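+    When ``dw_acc`` is the 1-D atomic accumulator, the caller is responsible for
+    zeroing it before the launch, since the kernel only adds fp32 contributions
+    into it atomically.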
+ """ + assert _LayerNormBackwardSM103 is not None + assert _can_use_ptr_path_bwd(x_2d, weight, dout_2d, rstd_1d, mean_1d) + + M, N = x_2d.shape + device_index = x_2d.get_device() + if torch.cuda.current_device() != device_index: + torch.cuda.set_device(device_index) + + dtype_x = TORCH2CUTE_DTYPE[x_2d.dtype] + dtype_w = TORCH2CUTE_DTYPE[weight.dtype] + assumed_align_x = 16 + assumed_align_w = 32 if weight.dtype == torch.float32 else 16 + assumed_align_dw = 32 + + stream_handle = int(torch.cuda.current_stream().cuda_stream) + stream = cuda.CUstream(stream_handle) + ld_val = int(x_2d.stride(0)) + + tpr_override, nt_override, cluster_n_override = _get_layernorm_bwd_tuning( + N, dtype_x + ) + + use_atomic_dw = dw_acc.dim() == 1 + assert db_acc is None or db_acc.dim() == 2 + + key = ( + "layernorm_bwd_atomic_ptr", + int(N), + dtype_x, + dtype_w, + bool(db_acc is not None), + bool(use_atomic_dw), + int(assumed_align_x), + int(assumed_align_w), + int(assumed_align_dw), + tpr_override, + nt_override, + cluster_n_override, + int(device_index), + ) + compiled = _BWD_PTR_COMPILE_CACHE.get(key) + if compiled is None: + op = _LayerNormBackwardSM103(dtype_x, N) + op.atomic_dw = bool(use_atomic_dw) + if tpr_override is not None: + op._tpr_override = tpr_override # type: ignore[attr-defined] + if nt_override is not None: + op._nt_override = nt_override # type: ignore[attr-defined] + if cluster_n_override is not None: + op._cluster_n_override = cluster_n_override # type: ignore[attr-defined] + + ptr_x = rt.make_ptr( + dtype_x, + x_2d.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, + ) + ptr_w = rt.make_ptr( + dtype_w, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_w, + ) + ptr_dout = rt.make_ptr( + dtype_x, + dout_2d.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, + ) + ptr_rstd = rt.make_ptr( + cutlass.Float32, + rstd_1d.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + ptr_mean = rt.make_ptr( + cutlass.Float32, + mean_1d.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + ptr_dx = rt.make_ptr( + dtype_x, + dx_2d.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, + ) + ptr_dw = rt.make_ptr( + cutlass.Float32, + dw_acc.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_dw, + ) + ptr_db = ( + rt.make_ptr( + cutlass.Float32, + db_acc.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_dw, + ) + if db_acc is not None + else None + ) + compiled = cute.compile( + op.launch_from_ptrs, + ptr_x, + ptr_w, + ptr_dout, + ptr_rstd, + ptr_mean, + ptr_dx, + ptr_dw, + ptr_db, + Int32(M), + Int32(ld_val), + Int32(int(sm_count)), + stream, + ) + _BWD_PTR_COMPILE_CACHE[key] = compiled + + ptr_x = rt.make_ptr( + dtype_x, + x_2d.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, + ) + ptr_w = rt.make_ptr( + dtype_w, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_w, + ) + ptr_dout = rt.make_ptr( + dtype_x, + dout_2d.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, + ) + ptr_rstd = rt.make_ptr( + cutlass.Float32, + rstd_1d.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + ptr_mean = rt.make_ptr( + cutlass.Float32, + mean_1d.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + ptr_dx = rt.make_ptr( + dtype_x, + dx_2d.data_ptr(), + mem_space=rt.AddressSpace.gmem, + 
assumed_align=assumed_align_x, + ) + ptr_dw = rt.make_ptr( + cutlass.Float32, + dw_acc.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_dw, + ) + ptr_db = ( + rt.make_ptr( + cutlass.Float32, + db_acc.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_dw, + ) + if db_acc is not None + else None + ) + compiled( + ptr_x, + ptr_w, + ptr_dout, + ptr_rstd, + ptr_mean, + ptr_dx, + ptr_dw, + ptr_db, + Int32(M), + Int32(ld_val), + Int32(int(sm_count)), + stream, + ) + + +def _layernorm_backward_ptr( + *, dout_2d: Tensor, x_2d: Tensor, weight: Tensor, rstd_1d: Tensor, mean_1d: Tensor, dx_2d: Tensor, + dw_partial: Tensor, + db_partial: Optional[Tensor], + sm_count: int, ) -> None: - """ - Host-side helper to run the dx-only LayerNorm backward kernel. - """ + assert _LayerNormBackwardSM103 is not None + assert _can_use_ptr_path_bwd(x_2d, weight, dout_2d, rstd_1d, mean_1d) + M, N = x_2d.shape - assert dout_2d.shape == (M, N) - assert rstd_1d.numel() == M - assert mean_1d.numel() == M + device_index = x_2d.get_device() + if torch.cuda.current_device() != device_index: + torch.cuda.set_device(device_index) - dtype = TORCH2CUTE_DTYPE[x_2d.dtype] + dtype_x = TORCH2CUTE_DTYPE[x_2d.dtype] + dtype_w = TORCH2CUTE_DTYPE[weight.dtype] + assumed_align_x = 16 + assumed_align_w = 32 if weight.dtype == torch.float32 else 16 + assumed_align_dw = 32 - mX = _convert_row_major(x_2d) - mdO = _convert_row_major(dout_2d) - mdX = _convert_row_major(dx_2d) + stream_handle = int(torch.cuda.current_stream().cuda_stream) + stream = cuda.CUstream(stream_handle) + ld_val = int(x_2d.stride(0)) - mW = convert_from_dlpack_cute( - weight.detach(), - leading_dim=0, - alignment=16, - divisibility=128 // cutlass.Float32.width, - ) - mRstd = from_dlpack(rstd_1d.detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=0 - ) - mMean = from_dlpack(mean_1d.detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=0 + tpr_override, nt_override, cluster_n_override = _get_layernorm_bwd_tuning( + N, dtype_x ) - stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - key = (N, dtype) - compiled = _BWD_DX_COMPILE_CACHE.get(key) + key = ( + "layernorm_bwd_ptr", + int(N), + dtype_x, + dtype_w, + bool(db_partial is not None), + int(assumed_align_x), + int(assumed_align_w), + int(assumed_align_dw), + tpr_override, + nt_override, + cluster_n_override, + int(device_index), + ) + compiled = _BWD_PTR_COMPILE_CACHE.get(key) if compiled is None: + op = _LayerNormBackwardSM103(dtype_x, N) + if tpr_override is not None: + op._tpr_override = tpr_override # type: ignore[attr-defined] + if nt_override is not None: + op._nt_override = nt_override # type: ignore[attr-defined] + if cluster_n_override is not None: + op._cluster_n_override = cluster_n_override # type: ignore[attr-defined] + + ptr_x = rt.make_ptr( + dtype_x, + x_2d.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, + ) + ptr_w = rt.make_ptr( + dtype_w, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_w, + ) + ptr_dout = rt.make_ptr( + dtype_x, + dout_2d.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, + ) + ptr_rstd = rt.make_ptr( + cutlass.Float32, + rstd_1d.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + ptr_mean = rt.make_ptr( + cutlass.Float32, + mean_1d.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + ptr_dx = rt.make_ptr( + dtype_x, + dx_2d.data_ptr(), + mem_space=rt.AddressSpace.gmem, + 
assumed_align=assumed_align_x, + ) + ptr_dw = rt.make_ptr( + cutlass.Float32, + dw_partial.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_dw, + ) + ptr_db = ( + rt.make_ptr( + cutlass.Float32, + db_partial.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_dw, + ) + if db_partial is not None + else None + ) compiled = cute.compile( - _layernorm_backward_dx, - mX, - mW, - mdO, - mRstd, - mMean, - mdX, + op.launch_from_ptrs, + ptr_x, + ptr_w, + ptr_dout, + ptr_rstd, + ptr_mean, + ptr_dx, + ptr_dw, + ptr_db, + Int32(M), + Int32(ld_val), + Int32(int(sm_count)), stream, ) - _BWD_DX_COMPILE_CACHE[key] = compiled + _BWD_PTR_COMPILE_CACHE[key] = compiled - compiled( - mX, - mW, - mdO, - mRstd, - mMean, - mdX, - stream, + launcher = _get_fast_ptr_layernorm_bwd_launcher( + compiled=compiled, + dtype=dtype_x, + weight_dtype=dtype_w, + N=N, + device_index=device_index, + stream_handle=stream_handle, + has_db_partial=db_partial is not None, + assumed_align_x=assumed_align_x, + assumed_align_w=assumed_align_w, + assumed_align_dw=assumed_align_dw, ) - - -def _layernorm_backward_params_sm100( - dout_2d: Tensor, - x_2d: Tensor, - rstd_1d: Tensor, - mean_1d: Tensor, - dw_partial: Optional[Tensor], - db_partial: Optional[Tensor], - sm_count: int, -) -> None: - """ - Host-side helper to run the parameter-gradient kernel that populates - dw_partial / db_partial of shape (sm_count, N). - """ - M, N = x_2d.shape - assert dout_2d.shape == (M, N) - assert rstd_1d.numel() == M - assert mean_1d.numel() == M - if dw_partial is None and db_partial is None: + if launcher is not None: + launcher.launch( + x=x_2d, + weight=weight, + dout=dout_2d, + rstd=rstd_1d, + mean=mean_1d, + dx=dx_2d, + dw_partial=dw_partial, + db_partial=db_partial, + M=M, + ld=ld_val, + sm_count=int(sm_count), + ) return - dtype = TORCH2CUTE_DTYPE[x_2d.dtype] - - mX = _convert_row_major(x_2d) - mdO = _convert_row_major(dout_2d) - mRstd = from_dlpack(rstd_1d.detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=0 + ptr_x = rt.make_ptr( + dtype_x, + x_2d.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, ) - mMean = from_dlpack(mean_1d.detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=0 + ptr_w = rt.make_ptr( + dtype_w, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_w, ) - - mdW_partial = ( - from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0) - if dw_partial is not None - else None + ptr_dout = rt.make_ptr( + dtype_x, + dout_2d.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, ) - mdB_partial = ( - from_dlpack(db_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0) + ptr_rstd = rt.make_ptr( + cutlass.Float32, + rstd_1d.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + ptr_mean = rt.make_ptr( + cutlass.Float32, + mean_1d.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + ptr_dx = rt.make_ptr( + dtype_x, + dx_2d.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, + ) + ptr_dw = rt.make_ptr( + cutlass.Float32, + dw_partial.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_dw, + ) + ptr_db = ( + rt.make_ptr( + cutlass.Float32, + db_partial.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_dw, + ) if db_partial is not None else None ) - - stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - has_bias = db_partial is not None 
- key = (N, dtype, has_bias) - compiled = _BWD_PARAM_COMPILE_CACHE.get(key) - if compiled is None: - compiled = cute.compile( - _layernorm_backward_param, - mX, - mdO, - mRstd, - mMean, - mdW_partial, - mdB_partial, - Int32(sm_count), - stream, - ) - _BWD_PARAM_COMPILE_CACHE[key] = compiled - compiled( - mX, - mdO, - mRstd, - mMean, - mdW_partial, - mdB_partial, - Int32(sm_count), + ptr_x, + ptr_w, + ptr_dout, + ptr_rstd, + ptr_mean, + ptr_dx, + ptr_dw, + ptr_db, + Int32(M), + Int32(ld_val), + Int32(int(sm_count)), stream, ) @@ -1883,28 +3756,110 @@ def layernorm_backward( bias: Optional[Tensor] = None, ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: """ - LayerNorm backward implemented in CuteDSL / CUTLASS. + LayerNorm backward implemented with CuTeDSL / CUTLASS-backed kernels. - Computes gradients w.r.t. input, weight, and optional bias using - two kernels: - - A dx kernel (CTA-per-row) that streams over N. - - A parameter-gradient kernel that accumulates dw/db over a - persistent grid of CTAs across the M dimension. + Tuned GB300 shapes use Oink's self-contained pointer fast-launch fused + LayerNorm backward schedule. Other shapes fall back to the local + self-contained dx and parameter-gradient kernels. """ assert x.shape == dout.shape, "x and dout must have the same shape" assert x.is_cuda and dout.is_cuda, "x and dout must be CUDA tensors" assert weight.dim() == 1, "weight must be 1D" + assert weight.shape[0] == x.shape[-1], "weight shape must match hidden dim" if bias is not None: assert bias.dim() == 1, "bias must be 1D" + assert bias.shape == weight.shape, "bias must match weight shape" + + use_ptr_path = ( + _LayerNormBackwardSM103 is not None and _should_use_layernorm_bwd_ptr(x, weight) + ) x_2d, orig_shape = _as_2d(x) dout_2d, _ = _as_2d(dout) M, N = x_2d.shape - - # Flatten to 2D for the kernels. 
mean_flat = mean.view(M) rstd_flat = rstd.view(M) + if use_ptr_path and _can_use_ptr_path_bwd( + x_2d, weight, dout_2d, rstd_flat, mean_flat + ): + device = x.device + sm_count = _get_layernorm_bwd_sm_count(N, device, M=M, dtype=x.dtype) + stream_handle = int(torch.cuda.current_stream(device=device).cuda_stream) + dx_2d = torch.empty_like(x_2d) + use_atomic_dw = _should_use_layernorm_bwd_atomic_dw_ptr(x, weight) + + if use_atomic_dw: + dw_atomic = _get_layernorm_bwd_atomic_dw_workspace( + device_index=x.get_device(), + stream_handle=stream_handle, + N=N, + ) + dw_atomic.zero_() + db_partial = ( + _get_layernorm_bwd_workspace( + device_index=x.get_device(), + stream_handle=stream_handle, + sm_count=int(sm_count), + N=N, + has_bias=True, + )[1] + if bias is not None + else None + ) + _layernorm_backward_atomic_ptr( + dout_2d=dout_2d, + x_2d=x_2d, + weight=weight, + rstd_1d=rstd_flat, + mean_1d=mean_flat, + dx_2d=dx_2d, + dw_acc=dw_atomic, + db_acc=db_partial, + sm_count=int(sm_count), + ) + dweight = _finalize_layernorm_bwd_atomic_dw(dw_acc=dw_atomic, weight=weight) + if bias is not None: + assert db_partial is not None + dbias = _reduce_partial_sum_fp32( + db_partial, device_index=bias.get_device() + ) + if bias.dtype != torch.float32: + dbias = dbias.to(bias.dtype) + else: + dbias = None + else: + dw_partial, db_partial = _get_layernorm_bwd_workspace( + device_index=x.get_device(), + stream_handle=stream_handle, + sm_count=int(sm_count), + N=N, + has_bias=bias is not None, + ) + _layernorm_backward_ptr( + dout_2d=dout_2d, + x_2d=x_2d, + weight=weight, + rstd_1d=rstd_flat, + mean_1d=mean_flat, + dx_2d=dx_2d, + dw_partial=dw_partial, + db_partial=db_partial, + sm_count=int(sm_count), + ) + # Keep the post-kernel reduction exact. Experimental custom reducers + # can change dW rounding, so use the trusted fp32 partial-sum finalizer. + dweight, dbias = _finalize_layernorm_bwd_partials( + dw_partial=dw_partial, + db_partial=db_partial, + weight=weight, + bias=bias, + device=device, + ) + dx = _restore_shape(dx_2d, orig_shape) + return dx, dweight, dbias + + # Flatten to 2D for the local fallback kernels. dx_2d = torch.empty_like(x_2d) _layernorm_backward_dx_sm100( dout_2d,