From 5a3083f19d85c2bb3c582fb91baa1493bbdd9d45 Mon Sep 17 00:00:00 2001 From: alphakhaw Date: Sun, 17 May 2026 17:01:14 +0800 Subject: [PATCH 01/16] hyena_se: fix vendored imports and Triton 3.3 compat --- vortex/ops/hyena_se/interface.py | 4 ++-- vortex/ops/hyena_se/kernel_utils.py | 16 ++++++++++------ vortex/ops/hyena_se/ref_bwd.py | 2 +- vortex/ops/hyena_se/ref_fwd.py | 2 +- 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/vortex/ops/hyena_se/interface.py b/vortex/ops/hyena_se/interface.py index 56235f6..2dd3369 100644 --- a/vortex/ops/hyena_se/interface.py +++ b/vortex/ops/hyena_se/interface.py @@ -11,7 +11,7 @@ from .fwd import two_pass_fwd_grouped, two_pass_fwd_grouped_refactor logger = logging.getLogger(__name__) -from hyena_ops.kernel_utils import ( +from .kernel_utils import ( BwdKernelConfig, FwdKernelConfig, ) @@ -21,7 +21,7 @@ except ImportError: old_causal_conv1d_fn = None -from hyena_ops.utils import toeplitz +from .utils import toeplitz @dataclass(eq=False) diff --git a/vortex/ops/hyena_se/kernel_utils.py b/vortex/ops/hyena_se/kernel_utils.py index 6107c43..29833a2 100644 --- a/vortex/ops/hyena_se/kernel_utils.py +++ b/vortex/ops/hyena_se/kernel_utils.py @@ -133,7 +133,7 @@ def is_power_of_2(n): @triton.jit -def get_program_ids(pid, tiles_per_seq, d_tiles_per_chunk, chunks_per_seq, SWIZZLE: tl.constexpr = "row"): +def get_program_ids(pid, tiles_per_seq, d_tiles_per_chunk, chunks_per_seq, SWIZZLE: tl.constexpr = 0): """ Converts 1-D program id to 3-D grid along batch, chunk (sequence) and d (feature) dimensions. @@ -142,18 +142,22 @@ def get_program_ids(pid, tiles_per_seq, d_tiles_per_chunk, chunks_per_seq, SWIZZ tiles_per_seq: number of tiles along sequence dimension d_tiles_per_chunk: number of tiles along d dimension chunks_per_seq: number of chunks along sequence dimension - SWIZZLE: "row" or "col", the threadblock launch order, where axis=0 corresponds to + SWIZZLE: 0/"row" or 1/"col", the threadblock launch order, where axis=0 corresponds to sequence dim and axis=1 to feature dimension - - "row" - row major tile order, where blocks are launched feature dimension then sequence dimension - - "col" - column major tile order, where blocks are launched sequence dimension then feature dimension + - 0 / "row" - row major tile order, blocks launched feature dimension then sequence dimension + - 1 / "col" - column major tile order, blocks launched sequence dimension then feature dimension NOTE: "col" should be more L2-cache friendly when grouping, since the same filter is used for each feature chunk + Triton >=3.3 rejects a Python str as a tl.constexpr *default* value + ("'str' object has no attribute 'type'" at the call site). The default + is therefore an int; the body still matches the "row"/"col" strings so + callers forwarding a string constexpr are unchanged. """ - if SWIZZLE == "row": + if SWIZZLE == 0 or SWIZZLE == "row": pid_batch = pid // tiles_per_seq pid_d = pid % d_tiles_per_chunk pid_chunk = (pid // d_tiles_per_chunk) % chunks_per_seq - elif SWIZZLE == "col": + elif SWIZZLE == 1 or SWIZZLE == "col": pid_batch = pid // tiles_per_seq pid_chunk = pid % chunks_per_seq pid_d = (pid // chunks_per_seq) % d_tiles_per_chunk diff --git a/vortex/ops/hyena_se/ref_bwd.py b/vortex/ops/hyena_se/ref_bwd.py index d956e30..5143548 100644 --- a/vortex/ops/hyena_se/ref_bwd.py +++ b/vortex/ops/hyena_se/ref_bwd.py @@ -1,5 +1,5 @@ import torch -from hyena_ops.utils import correction_toeplitz, toeplitz +from .utils import correction_toeplitz, toeplitz def dh_local(dT_local, g, hl, chunk_size): diff --git a/vortex/ops/hyena_se/ref_fwd.py b/vortex/ops/hyena_se/ref_fwd.py index 33d6d8e..4a0aad2 100644 --- a/vortex/ops/hyena_se/ref_fwd.py +++ b/vortex/ops/hyena_se/ref_fwd.py @@ -21,7 +21,7 @@ from einops import rearrange -from hyena_ops.utils import correction_toeplitz, toeplitz +from .utils import correction_toeplitz, toeplitz def gcg_fwd_ref_original(x, B, C, h, use_causal_conv=False): From a2a9a8729d6122e953767b077b5aa78992538e85 Mon Sep 17 00:00:00 2001 From: alphakhaw Date: Tue, 19 May 2026 18:56:27 +0800 Subject: [PATCH 02/16] feat: add fused Triton HCS depthwise-conv kernel First of the HC{S,M,L} inference kernels for vortex (refs #16, #76): a depthwise causal conv plus the gated hcs_conv adapter. --- vortex/ops/hcs_interface.py | 184 ++++++++++++++++++++++++++++++++++++ 1 file changed, 184 insertions(+) diff --git a/vortex/ops/hcs_interface.py b/vortex/ops/hcs_interface.py index e69de29..78b75b6 100644 --- a/vortex/ops/hcs_interface.py +++ b/vortex/ops/hcs_interface.py @@ -0,0 +1,184 @@ +# pyright: reportAttributeAccessIssue=none +""" +HCS -- Hyena Cascade Short. + +A from-scratch Triton depthwise causal 1D convolution for the short-filter +(fir_length < 128) gated branch of HyenaInferenceEngine.parallel_fir. + +The convolution is the only time-mixing op in an HCS layer: a depthwise +filter of fir_length taps (7 in evo2_7b) applied per channel. This module +provides the @triton.jit kernel and a thin Python launcher; the hcs_conv +adapter that wires it behind the use_hcs_kernel config flag is added alongside. +""" + +from typing import Callable + +import torch +import triton +import triton.language as tl + +# Autotuned search space for the conv kernel's register-tile sizes. Triton +# benchmarks these once per (D, L, FIR_LEN) and caches the winner, so no +# single tile size is hard-coded -- the GPU and shape pick it. +_AUTOTUNE_CONFIGS: list[triton.Config] = [ + triton.Config({"BLOCK_D": 32, "BLOCK_L": 64}, num_warps=2), + triton.Config({"BLOCK_D": 64, "BLOCK_L": 64}, num_warps=4), + triton.Config({"BLOCK_D": 64, "BLOCK_L": 128}, num_warps=4), + triton.Config({"BLOCK_D": 128, "BLOCK_L": 64}, num_warps=4), + triton.Config({"BLOCK_D": 64, "BLOCK_L": 256}, num_warps=8), + triton.Config({"BLOCK_D": 128, "BLOCK_L": 128}, num_warps=8), +] + + +@triton.autotune(configs=_AUTOTUNE_CONFIGS, key=["D", "L", "FIR_LEN"]) +@triton.jit +def _hcs_depthwise_conv_kernel( + u_ptr, + w_ptr, + z_ptr, + D, + L, + stride_ub, + stride_ud, + stride_ul, + stride_wd, + stride_wk, + FIR_LEN: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_L: tl.constexpr, +): + """ + Depthwise causal conv: z[b, d, t] = sum_k w[d, k] * u[b, d, t - FIR_LEN + 1 + k]. + + One program covers a (BLOCK_D, BLOCK_L) tile of one batch element. The + FIR_LEN tap loop is unrolled at compile time. Input positions before 0 + are masked to zero, giving a causal (left-padded) convolution. + """ + pid_b = tl.program_id(0) + pid_d = tl.program_id(1) + pid_l = tl.program_id(2) + + offs_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) + offs_l = pid_l * BLOCK_L + tl.arange(0, BLOCK_L) + mask_d = offs_d < D + mask_l = offs_l < L + + u_base = u_ptr + pid_b * stride_ub + offs_d[:, None] * stride_ud + acc = tl.zeros((BLOCK_D, BLOCK_L), dtype=tl.float32) + + for k in tl.static_range(FIR_LEN): + w_k = tl.load( + w_ptr + offs_d * stride_wd + k * stride_wk, mask=mask_d, other=0.0 + ) + pos = offs_l - (FIR_LEN - 1) + k + mask_pos = mask_d[:, None] & (pos[None, :] >= 0) & (pos[None, :] < L) + u_tile = tl.load(u_base + pos[None, :] * stride_ul, mask=mask_pos, other=0.0) + acc += w_k[:, None].to(tl.float32) * u_tile.to(tl.float32) + + z_ptrs = ( + z_ptr + + pid_b * stride_ub + + offs_d[:, None] * stride_ud + + offs_l[None, :] * stride_ul + ) + tl.store(z_ptrs, acc, mask=mask_d[:, None] & mask_l[None, :]) + + +def hcs_depthwise_conv(u: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """ + Depthwise causal 1D convolution, the HCS short-filter time-mixing op. + + Equivalent to F.conv1d(u, weight, padding=fir_length - 1, groups=D) + trimmed to length L, but in a single fused Triton launch. + + Args: + u (torch.Tensor): Input activations, shape (B, D, L), contiguous. + weight (torch.Tensor): Depthwise filter, shape (D, 1, fir_length), + contiguous. Every channel has its own filter. + + Returns: + torch.Tensor: Convolved output, shape (B, D, L), same dtype as u. + """ + if not u.is_contiguous(): + u = u.contiguous() + if not weight.is_contiguous(): + weight = weight.contiguous() + if u.dim() != 3 or weight.dim() != 3: + raise ValueError(f"expected 3-D u and weight, got {u.shape} and {weight.shape}") + + B, D, L = u.shape + Dw, in_per_group, fir_length = weight.shape + if Dw != D or in_per_group != 1: + raise ValueError(f"weight {tuple(weight.shape)} is not depthwise for D={D}") + + z: torch.Tensor = torch.empty_like(u) + # BLOCK_D / BLOCK_L are supplied by @triton.autotune; the grid is a + # callable so it can read the chosen tile sizes from the winning config. + grid: Callable[[triton.Config], tuple[int, int, int]] = lambda meta: ( + B, + triton.cdiv(D, meta["BLOCK_D"]), + triton.cdiv(L, meta["BLOCK_L"]), + ) + _hcs_depthwise_conv_kernel[grid]( + u, + weight, + z, + D, + L, + u.stride(0), + u.stride(1), + u.stride(2), + weight.stride(0), + weight.stride(2), + FIR_LEN=fir_length, + ) + return z + + +def hcs_conv( + x1: torch.Tensor, + x2: torch.Tensor, + v: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None = None, + *, + gated_bias: bool = False, + padding_mask: torch.Tensor | None = None, +) -> torch.Tensor: + """ + Fully-gated HCS short conv: z = x2 * (conv(x1 * v, weight) + bias). + + Drop-in replacement for the gated fir_length < 128 branch of + HyenaInferenceEngine.parallel_fir. It reproduces that branch exactly: + the depthwise conv runs in fp32 for numerical parity with the F.conv1d + path, the result is cast back to the activation dtype, bias-added, + masked, then closed with the post-gate multiply by x2. + + Args: + x1 (torch.Tensor): Pre-gate "key" stream, shape (B, D, L). + x2 (torch.Tensor): Post-gate stream, shape (B, D, L). + v (torch.Tensor): "Value" stream, shape (B, D, L). + weight (torch.Tensor): Depthwise filter, shape (D, 1, fir_length). + bias (torch.Tensor | None): Per-channel skip-gain, shape (D,). + gated_bias (bool): If True the bias is applied multiplicatively + (bias * x1 * v); HCS uses additive bias (False). + padding_mask (torch.Tensor | None): If a tensor, zeros masked + positions after the conv, shape (B, L). + + Returns: + torch.Tensor: Gated HCS output, shape (B, D, L), x1's dtype. + """ + u: torch.Tensor = x1 * v + z: torch.Tensor = hcs_depthwise_conv(u=u.float(), weight=weight.float()) + z = z.to(u.dtype) + + if bias is not None: + if gated_bias: + z = z + bias[None, :, None] * u + else: + z = z + bias[None, :, None] + + if isinstance(padding_mask, torch.Tensor): + z = z * padding_mask[:, None] + + return x2 * z From 3b55461998c93bfab1a4fd58e5e525d28e01bdad Mon Sep 17 00:00:00 2001 From: alphakhaw Date: Tue, 19 May 2026 18:56:27 +0800 Subject: [PATCH 03/16] feat: dispatch the HCS kernel behind the use_hcs_kernel flag parallel_fir routes the gated short-filter branch through the Triton HCS kernel when use_hcs_kernel is set; default off is a byte-identical no-op. refs #16, #76 --- vortex/model/engine.py | 22 ++++++++++++++++++---- vortex/model/model.py | 1 + 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/vortex/model/engine.py b/vortex/model/engine.py index 3a6ec35..3b5bca0 100644 --- a/vortex/model/engine.py +++ b/vortex/model/engine.py @@ -5,13 +5,17 @@ import torch import torch.nn.functional as F -try: - pass -except: - pass from vortex.model.utils import column_split from vortex.logging import activations_logger +# vortex-kernels: optional fused Triton HCS kernel (refs #16, #76). hcs_interface +# imports triton -- a Linux/GPU-only optional dependency -- so this import is +# guarded; `import vortex` must still succeed where triton is not installed. +try: + from vortex.ops.hcs_interface import hcs_conv +except ImportError: + hcs_conv = None + IIR_PREFILL_MODES = [ "recurrence", "modal-fft", @@ -118,6 +122,7 @@ def __init__( ground_truth_activations_path=None, print_activations=False, hyena_flip_x1x2=False, + use_hcs_kernel=False, ) -> None: self.fir_fn = fir_fn assert iir_prefill_style in IIR_PREFILL_MODES, f"iir_prefill_style must be one of {IIR_PREFILL_MODES}" @@ -127,6 +132,7 @@ def __init__( self.ground_truth_activations_path = ground_truth_activations_path self.print_activations = print_activations self.hyena_flip_x1x2 = hyena_flip_x1x2 + self.use_hcs_kernel = use_hcs_kernel def parallel_fir( self, @@ -156,6 +162,14 @@ def parallel_fir( x2, x1, v = u.split([hidden_size, hidden_size, hidden_size], dim=1) if self.hyena_flip_x1x2: x1, x2 = x2, x1 + + # vortex-kernels: opt-in fused Triton HCS short conv (refs #16, #76). + # Matches only the gated short-filter cascade; flag off is a no-op. + if self.use_hcs_kernel and hcs_conv is not None and fir_length < 128 and groups: + z = hcs_conv(x1, x2, v, weight, bias, gated_bias=gated_bias, padding_mask=padding_mask) + fir_state = (x1 * v)[..., -fir_length + 1 :] if inference_params is not None else None + return z, fir_state + u = x1 * v if self.print_activations: diff --git a/vortex/model/model.py b/vortex/model/model.py index e9ddec5..dba0d66 100644 --- a/vortex/model/model.py +++ b/vortex/model/model.py @@ -157,6 +157,7 @@ def __init__(self, config, layer_idx, hyena_filter_groups=None, fir_inner_filter ground_truth_activations_path=self.ground_truth_activations_path, print_activations=self.print_activations, hyena_flip_x1x2=config.get("hyena_flip_x1x2", False), + use_hcs_kernel=config.get("use_hcs_kernel", False), ) self.use_flash_depthwise = config.get("use_flash_depthwise", False) self.data_dtype = None From 81b8aa96933574132bdeced0f7edf0574559c6cc Mon Sep 17 00:00:00 2001 From: alphakhaw Date: Tue, 19 May 2026 18:56:27 +0800 Subject: [PATCH 04/16] test: add HCS kernel, adapter, engine, and e2e tests refs #16, #76 --- test/test_hcs_adapter.py | 114 +++++++++++++++++++++++++++++++ test/test_hcs_e2e.py | 102 ++++++++++++++++++++++++++++ test/test_hcs_engine.py | 140 +++++++++++++++++++++++++++++++++++++++ test/test_hcs_kernel.py | 87 ++++++++++++++++++++++++ 4 files changed, 443 insertions(+) create mode 100644 test/test_hcs_adapter.py create mode 100644 test/test_hcs_e2e.py create mode 100644 test/test_hcs_engine.py create mode 100644 test/test_hcs_kernel.py diff --git a/test/test_hcs_adapter.py b/test/test_hcs_adapter.py new file mode 100644 index 0000000..bdf9cbc --- /dev/null +++ b/test/test_hcs_adapter.py @@ -0,0 +1,114 @@ +""" +Tests for the hcs_conv adapter -- the drop-in for the gated HCS branch of +HyenaInferenceEngine.parallel_fir. + +The oracle is a pure-PyTorch transcription of that branch (engine.py, the +fir_length < 128 path plus the gate). hcs_conv must reproduce it exactly, +so the use_hcs_kernel flag off vs on is a behavioural no-op. +""" + +import pytest +import torch +import torch.nn.functional as F + +from vortex.ops.hcs_interface import hcs_conv + +CUDA: bool = torch.cuda.is_available() + + +def _hcs_branch_ref(x1, x2, v, weight, bias, gated_bias, padding_mask): + """ + Pure-PyTorch reference for the gated HCS branch of parallel_fir. + """ + D, L = v.shape[1], v.shape[2] + u = x1 * v + z = F.conv1d( + u.float(), + weight.float(), + bias=None, + stride=1, + padding=weight.shape[-1] - 1, + groups=D, + )[..., :L] + z = z.to(u.dtype) + if bias is not None: + z = z + bias[None, :, None] * u if gated_bias else z + bias[None, :, None] + if isinstance(padding_mask, torch.Tensor): + z = z * padding_mask[:, None] + return x2 * z + + +@pytest.mark.skipif(not CUDA, reason="hcs_conv requires CUDA") +@pytest.mark.parametrize("L", [1024, 8192]) +@pytest.mark.parametrize("with_bias", [True, False]) +def test_hcs_conv_matches_engine_branch(L: int, with_bias: bool) -> None: + """ + hcs_conv reproduces the flag-off parallel_fir HCS branch in fp32. + """ + torch.manual_seed(0) + B, D, K = 1, 4096, 7 + x1 = torch.randn(B, D, L, dtype=torch.float32, device="cuda") + x2 = torch.randn_like(x1) + v = torch.randn_like(x1) + weight = torch.randn(D, 1, K, dtype=torch.float32, device="cuda") + bias = torch.randn(D, dtype=torch.float32, device="cuda") if with_bias else None + + z = hcs_conv(x1, x2, v, weight, bias) + z_ref = _hcs_branch_ref(x1, x2, v, weight, bias, False, None) + assert (z - z_ref).abs().max().item() < 1e-3 + + +@pytest.mark.skipif(not CUDA, reason="hcs_conv requires CUDA") +def test_hcs_conv_gated_bias() -> None: + """ + The gated_bias=True path applies the bias multiplicatively. + """ + torch.manual_seed(0) + B, D, L, K = 1, 4096, 2048, 7 + x1 = torch.randn(B, D, L, dtype=torch.float32, device="cuda") + x2 = torch.randn_like(x1) + v = torch.randn_like(x1) + weight = torch.randn(D, 1, K, dtype=torch.float32, device="cuda") + bias = torch.randn(D, dtype=torch.float32, device="cuda") + + z = hcs_conv(x1, x2, v, weight, bias, gated_bias=True) + z_ref = _hcs_branch_ref(x1, x2, v, weight, bias, True, None) + assert (z - z_ref).abs().max().item() < 1e-3 + + +@pytest.mark.skipif(not CUDA, reason="hcs_conv requires CUDA") +def test_hcs_conv_padding_mask() -> None: + """ + A padding_mask tensor zeros masked positions in the output. + """ + torch.manual_seed(0) + B, D, L, K = 1, 256, 1024, 7 + x1 = torch.randn(B, D, L, dtype=torch.float32, device="cuda") + x2 = torch.randn_like(x1) + v = torch.randn_like(x1) + weight = torch.randn(D, 1, K, dtype=torch.float32, device="cuda") + mask = torch.ones(B, L, dtype=torch.float32, device="cuda") + mask[:, L // 2 :] = 0.0 + + z = hcs_conv(x1, x2, v, weight, None, padding_mask=mask) + z_ref = _hcs_branch_ref(x1, x2, v, weight, None, False, mask) + assert (z - z_ref).abs().max().item() < 1e-3 + assert z[..., L // 2 :].abs().max().item() == 0.0 + + +@pytest.mark.skipif(not CUDA, reason="hcs_conv requires CUDA") +def test_hcs_conv_bf16() -> None: + """ + hcs_conv matches the engine branch in bf16, the real inference dtype. + """ + torch.manual_seed(0) + B, D, L, K = 1, 4096, 2048, 7 + x1 = torch.randn(B, D, L, dtype=torch.bfloat16, device="cuda") + x2 = torch.randn_like(x1) + v = torch.randn_like(x1) + weight = torch.randn(D, 1, K, dtype=torch.bfloat16, device="cuda") + bias = torch.randn(D, dtype=torch.bfloat16, device="cuda") + + z = hcs_conv(x1, x2, v, weight, bias) + z_ref = _hcs_branch_ref(x1, x2, v, weight, bias, False, None) + torch.testing.assert_close(z, z_ref, rtol=2e-2, atol=2e-2) diff --git a/test/test_hcs_e2e.py b/test/test_hcs_e2e.py new file mode 100644 index 0000000..0a07409 --- /dev/null +++ b/test/test_hcs_e2e.py @@ -0,0 +1,102 @@ +""" +End-to-end test for the use_hcs_kernel flag inside a real Evo2 model. + +Loads an Evo2 checkpoint once and runs the same forward twice -- with the +HCS kernel off (stock F.conv1d) and on (the fused Triton HCS kernel). The +flag changes only the HCS conv implementation, so the logits must agree. + +The model id defaults to evo2_7b (the variant cached on the dev box) and is +overridable with VK_E2E_MODEL. The checkpoint is large; the test skips +cleanly if it cannot be loaded. +""" + +import os + +import pytest +import torch + +from vortex.model.engine import HyenaInferenceEngine + +CUDA: bool = torch.cuda.is_available() +_MODEL_ID: str = os.environ.get("VK_E2E_MODEL", "evo2_7b") +_SEQ_LEN: int = 2048 + + +@pytest.fixture(scope="module") +def evo2_model(): + """ + Load the Evo2 model once for the module, or skip if unavailable. + """ + if not CUDA: + pytest.skip("Evo2 e2e test requires CUDA") + try: + from evo2 import Evo2 + except ImportError as exc: # pragma: no cover - depends on optional dep + pytest.skip(f"evo2 not installed: {exc}") + try: + return Evo2(_MODEL_ID) + except Exception as exc: # noqa: BLE001 - any load failure -> skip, not fail + pytest.skip(f"could not load Evo2({_MODEL_ID!r}): {exc}") + + +def _set_hcs_kernel(model, enabled: bool) -> int: + """ + Flip use_hcs_kernel on every HyenaInferenceEngine reachable from model. + + The flag is read off the engine instance (normally set from the + use_hcs_kernel config key at build time); toggling it on the loaded + model lets one checkpoint serve both halves of the comparison. + + Args: + model: A loaded Evo2 model. + enabled (bool): Target value for use_hcs_kernel. + + Returns: + The number of HyenaInferenceEngine instances touched. + """ + root = getattr(model, "model", model) + touched = 0 + for module in root.modules(): + engine = getattr(module, "engine", None) + if isinstance(engine, HyenaInferenceEngine): + engine.use_hcs_kernel = enabled + touched += 1 + return touched + + +def _logits(model, input_ids: torch.Tensor) -> torch.Tensor: + """ + Run a forward pass and return the logits tensor as fp32. + + Evo2 wraps its output in nested tuples, so descend to the first tensor. + """ + with torch.no_grad(): + out = model(input_ids) + while isinstance(out, (tuple, list)): + out = out[0] + return out.float() + + +@pytest.mark.skipif(not CUDA, reason="Evo2 e2e test requires CUDA") +def test_vk_hcs_e2e_matches_baseline(evo2_model) -> None: + """ + A full Evo2 forward is unchanged when use_hcs_kernel swaps in the kernel. + """ + torch.manual_seed(0) + input_ids = torch.randint(1, 5, (1, _SEQ_LEN), dtype=torch.int, device="cuda:0") + + try: + touched = _set_hcs_kernel(evo2_model, False) + assert touched > 0, "no HyenaInferenceEngine found in the Evo2 model" + logits_off = _logits(evo2_model, input_ids) + + _set_hcs_kernel(evo2_model, True) + logits_on = _logits(evo2_model, input_ids) + finally: + _set_hcs_kernel(evo2_model, False) + + assert logits_on.shape == logits_off.shape + max_diff = (logits_on - logits_off).abs().max().item() + assert max_diff < 5e-2, ( + f"use_hcs_kernel changed the logits: max_diff={max_diff:.3e}" + ) diff --git a/test/test_hcs_engine.py b/test/test_hcs_engine.py new file mode 100644 index 0000000..291e579 --- /dev/null +++ b/test/test_hcs_engine.py @@ -0,0 +1,140 @@ +""" +Wiring tests for the use_hcs_kernel branch in HyenaInferenceEngine.parallel_fir. + +These exercise the engine branch + hcs_conv adapter + Triton kernel together +by calling parallel_fir directly (no model weights needed). With the kernel +enabled, the gated short-filter cascade must reproduce the stock F.conv1d +path; the branch must not fire for the featurizer or the HCM-length cascade. +""" + +import pytest +import torch +import torch.nn.functional as F + +from vortex.model.engine import HyenaInferenceEngine + +CUDA: bool = torch.cuda.is_available() + +# evo2_7b HCS cascade shapes: D=4096, hcs_filter_groups=256, fir_length=7. +B, D, K, GROUPS = 1, 4096, 7, 256 +DIMS: tuple[int, int, int, int, int] = (D, 32, D // 32, 16, GROUPS) +_CASCADE_KW: dict[str, bool | int | None] = dict( + groups=GROUPS, + gated_bias=False, + column_split_hyena=False, + dim_last=False, + fir_length=K, + gate=True, +) + + +def _hcs_inputs(L: int, dtype: torch.dtype): + """ + Build (u, weight, bias) for an HCS cascade parallel_fir call. + """ + u = torch.randn(B, 3 * D, L, dtype=dtype, device="cuda") + weight = torch.randn(D, 1, K, dtype=dtype, device="cuda") + bias = torch.randn(D, dtype=dtype, device="cuda") + return u, weight, bias + + +@pytest.mark.skipif(not CUDA, reason="HCS kernel requires CUDA") +@pytest.mark.parametrize("L", [1024, 8192]) +def test_vk_hcs_on_matches_baseline_fp32(L: int) -> None: + """ + use_hcs_kernel on reproduces the stock parallel_fir HCS output in fp32. + """ + torch.manual_seed(0) + engine = HyenaInferenceEngine(layer_idx=0) + u, weight, bias = _hcs_inputs(L, torch.float32) + + z_off, state_off = engine.parallel_fir( + F.conv1d, u, weight, bias, L, DIMS, **_CASCADE_KW + ) + + engine.use_hcs_kernel = True + z_on, state_on = engine.parallel_fir( + F.conv1d, u, weight, bias, L, DIMS, **_CASCADE_KW + ) + + assert state_off is None and state_on is None + assert z_on.shape == z_off.shape == (B, D, L) + assert (z_on - z_off).abs().max().item() < 1e-3 + + +@pytest.mark.skipif(not CUDA, reason="HCS kernel requires CUDA") +def test_vk_hcs_on_matches_baseline_bf16() -> None: + """ + use_hcs_kernel on reproduces the stock HCS output in bf16, the inference dtype. + """ + torch.manual_seed(0) + engine = HyenaInferenceEngine(layer_idx=0) + u, weight, bias = _hcs_inputs(4096, torch.bfloat16) + + z_off, _ = engine.parallel_fir(F.conv1d, u, weight, bias, 4096, DIMS, **_CASCADE_KW) + + engine.use_hcs_kernel = True + z_on, _ = engine.parallel_fir(F.conv1d, u, weight, bias, 4096, DIMS, **_CASCADE_KW) + + torch.testing.assert_close(z_on, z_off, rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not CUDA, reason="HCS kernel requires CUDA") +def test_vk_hcs_off_by_default() -> None: + """ + A fresh HyenaInferenceEngine has use_hcs_kernel False, so parallel_fir + takes the stock path with no behavioural change. + """ + torch.manual_seed(0) + engine = HyenaInferenceEngine(layer_idx=0) + assert engine.use_hcs_kernel is False + + u, weight, bias = _hcs_inputs(2048, torch.float32) + z_default, _ = engine.parallel_fir( + F.conv1d, u, weight, bias, 2048, DIMS, **_CASCADE_KW + ) + + explicit_off = HyenaInferenceEngine(layer_idx=0, use_hcs_kernel=False) + z_explicit, _ = explicit_off.parallel_fir( + F.conv1d, u, weight, bias, 2048, DIMS, **_CASCADE_KW + ) + + # both took the stock path, so they are bitwise identical + assert (z_default - z_explicit).abs().max().item() == 0.0 + + +@pytest.mark.skipif(not CUDA, reason="HCS kernel requires CUDA") +def test_vk_hcs_predicate_skips_non_hcs_calls() -> None: + """ + The branch matches only the gated short cascade -- not the featurizer + (gate=False) and not the HCM-length cascade (fir_length >= 128). + """ + torch.manual_seed(0) + engine = HyenaInferenceEngine(layer_idx=0, use_hcs_kernel=True) + + # featurizer: gate=False, dim_last=True, no split -> branch unreachable + u_feat = torch.randn(B, 1024, 3 * D, dtype=torch.float32, device="cuda") + w_feat = torch.randn(3 * D, 1, 3, dtype=torch.float32, device="cuda") + z_feat, _ = engine.parallel_fir( + F.conv1d, + u_feat, + w_feat, + None, + 1024, + DIMS, + groups=None, + dim_last=True, + fir_length=3, + gate=False, + ) + assert z_feat.shape == (B, 3 * D, 1024) + + # HCM-length cascade: fir_length=128 fails the `< 128` predicate, so the + # call still routes through the stock fftconv_func path. + u_hcm = torch.randn(B, 3 * D, 1024, dtype=torch.float32, device="cuda") + w_hcm = torch.randn(D, 1, 128, dtype=torch.float32, device="cuda") + b_hcm = torch.randn(D, dtype=torch.float32, device="cuda") + kw = dict(_CASCADE_KW) + kw["fir_length"] = 128 + z_hcm, _ = engine.parallel_fir(F.conv1d, u_hcm, w_hcm, b_hcm, 1024, DIMS, **kw) + assert z_hcm.shape == (B, D, 1024) diff --git a/test/test_hcs_kernel.py b/test/test_hcs_kernel.py new file mode 100644 index 0000000..9f67979 --- /dev/null +++ b/test/test_hcs_kernel.py @@ -0,0 +1,87 @@ +""" +Correctness tests for the HCS depthwise causal conv Triton kernel. + +The oracle is F.conv1d on the depthwise causal path (left-pad by +fir_length - 1, trim to L) -- exactly the convolution +HyenaInferenceEngine.parallel_fir applies in its fir_length < 128 branch. +""" + +import pytest +import torch +import torch.nn.functional as F + +from vortex.ops.hcs_interface import hcs_depthwise_conv + +CUDA: bool = torch.cuda.is_available() + + +def _conv1d_ref(u: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """ + Reference depthwise causal conv via F.conv1d, trimmed to length L. + """ + fir_length = weight.shape[-1] + L = u.shape[-1] + return F.conv1d( + u, weight, bias=None, stride=1, padding=fir_length - 1, groups=u.shape[1] + )[..., :L] + + +@pytest.mark.skipif(not CUDA, reason="HCS Triton kernel requires CUDA") +@pytest.mark.parametrize("L", [1024, 8192, 32768]) +@pytest.mark.parametrize("fir_length", [3, 7]) +def test_hcs_conv_matches_conv1d(L: int, fir_length: int) -> None: + """ + The HCS Triton depthwise conv matches F.conv1d at evo2_7b shapes. + """ + torch.manual_seed(0) + B, D = 1, 4096 + u = torch.randn(B, D, L, dtype=torch.float32, device="cuda") + weight = torch.randn(D, 1, fir_length, dtype=torch.float32, device="cuda") + + z = hcs_depthwise_conv(u, weight) + z_ref = _conv1d_ref(u, weight) + + max_diff = (z - z_ref).abs().max().item() + mean_diff = (z - z_ref).abs().mean().item() + assert max_diff < 1e-3, f"max_diff={max_diff:.2e}" + assert mean_diff < 1e-4, f"mean_diff={mean_diff:.2e}" + + +@pytest.mark.skipif(not CUDA, reason="HCS Triton kernel requires CUDA") +def test_hcs_conv_is_causal() -> None: + """ + Output position t depends only on inputs at or before t. + """ + torch.manual_seed(0) + B, D, L, fir_length = 1, 64, 256, 7 + u = torch.randn(B, D, L, dtype=torch.float32, device="cuda") + weight = torch.randn(D, 1, fir_length, dtype=torch.float32, device="cuda") + + z = hcs_depthwise_conv(u, weight) + u_perturbed = u.clone() + u_perturbed[..., L // 2] += 100.0 + z_perturbed = hcs_depthwise_conv(u_perturbed, weight) + + # positions strictly before L//2 must be unchanged + before = (z[..., : L // 2] - z_perturbed[..., : L // 2]).abs().max().item() + after = (z[..., L // 2 :] - z_perturbed[..., L // 2 :]).abs().max().item() + assert before == 0.0, ( + f"non-causal: positions before the perturbation changed by {before:.2e}" + ) + assert after > 0.0, "perturbation had no effect on later positions" + + +@pytest.mark.skipif(not CUDA, reason="HCS Triton kernel requires CUDA") +@pytest.mark.parametrize("D", [16, 4096, 4100]) +def test_hcs_conv_ragged_channels(D: int) -> None: + """ + The kernel masks channel tiles that do not divide BLOCK_D evenly. + """ + torch.manual_seed(0) + B, L, fir_length = 1, 1024, 7 + u = torch.randn(B, D, L, dtype=torch.float32, device="cuda") + weight = torch.randn(D, 1, fir_length, dtype=torch.float32, device="cuda") + + z = hcs_depthwise_conv(u, weight) + z_ref = _conv1d_ref(u, weight) + assert (z - z_ref).abs().max().item() < 1e-3 From 4264f7d0caf97a8161010ce556c71dcf7c490b9c Mon Sep 17 00:00:00 2001 From: alphakhaw Date: Tue, 19 May 2026 21:27:07 +0800 Subject: [PATCH 05/16] feat: add the HCM FFT-conv complex-multiply kernel --- test/test_hcm_kernel.py | 69 ++++++++++++++++++ vortex/ops/hcm_interface.py | 140 ++++++++++++++++++++++++++++++++++++ 2 files changed, 209 insertions(+) create mode 100644 test/test_hcm_kernel.py diff --git a/test/test_hcm_kernel.py b/test/test_hcm_kernel.py new file mode 100644 index 0000000..5e44e21 --- /dev/null +++ b/test/test_hcm_kernel.py @@ -0,0 +1,69 @@ +""" +Correctness tests for the HCM fused complex-multiply Triton kernel. + +_hcm_complex_mul computes stage 3 of the HCM FFT-conv -- the elementwise +product of the activation and filter spectra -- with stage 1's 1/fft_size +filter normalisation folded in. The oracle is the explicit torch expression +fftconv_func evaluates: u_f * (rfft(k) / fft_size). +""" + +import pytest +import torch + +from vortex.model.engine import adjust_filter_shape_for_broadcast +from vortex.ops.hcm_interface import _hcm_complex_mul + +CUDA: bool = torch.cuda.is_available() + + +def _spectra( + B: int, D: int, L: int, fir_length: int +) -> tuple[torch.Tensor, torch.Tensor, int]: + """ + Build the activation and filter spectra exactly as fftconv_func does. + + Returns: + tuple[torch.Tensor, torch.Tensor, int]: u_f, the activation spectrum + of shape (B, D, F); k_f, the unscaled filter spectrum of shape + (1, D, F) broadcast over the batch; and the integer fft_size. + """ + torch.manual_seed(0) + fft_size = 2 * L + u = torch.randn(B, D, L, dtype=torch.float32, device="cuda") + k = torch.randn(D, fir_length, dtype=torch.float32, device="cuda") + u_f = torch.fft.rfft(u, n=fft_size) + k_f = adjust_filter_shape_for_broadcast(u, torch.fft.rfft(k, n=fft_size)) + return u_f, k_f, fft_size + + +@pytest.mark.skipif(not CUDA, reason="HCM Triton kernel requires CUDA") +@pytest.mark.parametrize("B", [1, 2]) +@pytest.mark.parametrize("L", [1024, 8192, 32768]) +def test_hcm_complex_mul_matches_oracle(B: int, L: int) -> None: + """ + The fused complex multiply matches u_f * (k_f / fft_size) at evo2_7b + shapes; B=2 exercises the filter broadcast over the batch. + """ + u_f, k_f, fft_size = _spectra(B, D=4096, L=L, fir_length=128) + + y_f = _hcm_complex_mul(u_f, k_f, fft_size) + y_f_ref = u_f * (k_f / fft_size) + + assert y_f.shape == u_f.shape + assert y_f.dtype == u_f.dtype + max_diff = (y_f - y_f_ref).abs().max().item() + mean_diff = (y_f - y_f_ref).abs().mean().item() + assert max_diff < 1e-4, f"max_diff={max_diff:.2e}" + assert mean_diff < 1e-5, f"mean_diff={mean_diff:.2e}" + + +@pytest.mark.skipif(not CUDA, reason="HCM Triton kernel requires CUDA") +def test_hcm_complex_mul_masks_ragged_tail() -> None: + """ + The kernel masks the flat (D, F) tail when D*F does not divide the tile. + """ + u_f, k_f, fft_size = _spectra(B=1, D=16, L=100, fir_length=4) + + y_f = _hcm_complex_mul(u_f, k_f, fft_size) + y_f_ref = u_f * (k_f / fft_size) + assert (y_f - y_f_ref).abs().max().item() < 1e-4 diff --git a/vortex/ops/hcm_interface.py b/vortex/ops/hcm_interface.py index e69de29..44764d0 100644 --- a/vortex/ops/hcm_interface.py +++ b/vortex/ops/hcm_interface.py @@ -0,0 +1,140 @@ +# pyright: reportAttributeAccessIssue=none +""" +HCM -- Hyena Cascade Medium. + +Fused Triton epilogues for the FFT-convolution path of +HyenaInferenceEngine.parallel_fir (the fir_length >= 128 branch). At a +128-tap filter Triton cannot out-write cuFFT for the transforms themselves, +so the win is launch-count: the elementwise glue around the three cuFFT +calls is fused into Triton kernels. + +This module currently provides _hcm_complex_mul -- stage 3 of fftconv_func, +the broadcast complex product u_f * k_f, with stage 1's 1/fft_size filter +normalisation folded in. +""" + +from typing import Callable + +import torch +import triton +import triton.language as tl + +# Autotuned search space for the flat (D, F) tile. The op is memory-bound +# elementwise, so the winner is whichever BLOCK best saturates bandwidth for +# a given problem size -- Triton benchmarks these once per DF and caches it. +_AUTOTUNE_CONFIGS: list[triton.Config] = [ + triton.Config({"BLOCK": 256}, num_warps=2), + triton.Config({"BLOCK": 512}, num_warps=4), + triton.Config({"BLOCK": 1024}, num_warps=4), + triton.Config({"BLOCK": 2048}, num_warps=8), +] + + +@triton.autotune(configs=_AUTOTUNE_CONFIGS, key=["DF"]) +@triton.jit +def _hcm_complex_mul_kernel( + u_ptr, + k_ptr, + y_ptr, + DF, + inv_fft_size, + stride_batch, + BLOCK: tl.constexpr, +): + """ + Broadcast complex multiply: y[b] = u_f[b] * k_f[0] * inv_fft_size. + + One program covers a BLOCK-element slice of the flattened (D, F) plane + of one batch element. Complex values are stored interleaved (real, imag) + -- view_as_real's layout -- so element n's real part is at offset 2n and + its imag part at 2n + 1. The filter k_f carries no batch stride: every + batch element multiplies against the same spectrum. + + The flat (D, F) tile is grid axis 0: cdiv(DF, BLOCK) overruns the 65535 + cap on axes 1 and 2 at long context, so the small batch sits on axis 1. + """ + pid = tl.program_id(0) + pid_b = tl.program_id(1) + + offs = pid * BLOCK + tl.arange(0, BLOCK) + mask = offs < DF + + u_base = u_ptr + pid_b * stride_batch + y_base = y_ptr + pid_b * stride_batch + + u_re = tl.load(u_base + 2 * offs, mask=mask, other=0.0) + u_im = tl.load(u_base + 2 * offs + 1, mask=mask, other=0.0) + k_re = tl.load(k_ptr + 2 * offs, mask=mask, other=0.0) + k_im = tl.load(k_ptr + 2 * offs + 1, mask=mask, other=0.0) + + y_re = (u_re * k_re - u_im * k_im) * inv_fft_size + y_im = (u_re * k_im + u_im * k_re) * inv_fft_size + + tl.store(y_base + 2 * offs, y_re, mask=mask) + tl.store(y_base + 2 * offs + 1, y_im, mask=mask) + + +def _hcm_complex_mul( + u_f: torch.Tensor, k_f: torch.Tensor, fft_size: int +) -> torch.Tensor: + """ + Fused broadcast complex multiply with the 1/fft_size filter scale. + + Computes u_f * k_f / fft_size -- stage 3 of fftconv_func, with stage 1's + filter normalisation folded in. u_f is the activation spectrum; k_f is + the *unscaled* filter spectrum, already shaped for broadcast over the + batch by adjust_filter_shape_for_broadcast. + + Args: + u_f (torch.Tensor): Activation spectrum, complex, shape (B, D, F). + k_f (torch.Tensor): Filter spectrum, complex, shape (1, D, F), shared + across the batch and not yet scaled by 1/fft_size. + fft_size (int): The FFT length n = 2 * seqlen; its reciprocal folds + in as the filter normalisation. + + Returns: + torch.Tensor: The scaled product u_f * k_f / fft_size, complex, shape + (B, D, F), u_f's dtype. + + Raises: + ValueError: If the tensors are not 3-D complex, or k_f is not + broadcastable over the batch of u_f. + """ + if u_f.dim() != 3 or k_f.dim() != 3: + raise ValueError( + f"expected 3-D u_f and k_f, got {tuple(u_f.shape)} and {tuple(k_f.shape)}" + ) + if not u_f.is_complex() or not k_f.is_complex(): + raise ValueError("u_f and k_f must be complex tensors") + + B, D, F = u_f.shape + if tuple(k_f.shape) != (1, D, F): + raise ValueError( + f"k_f {tuple(k_f.shape)} is not broadcastable over u_f {tuple(u_f.shape)}" + ) + + u_f = u_f.contiguous() + k_f = k_f.contiguous() + y_f: torch.Tensor = torch.empty_like(u_f) + + # Triton has no complex dtype: operate on the (..., 2) real/imag view. + u_r = torch.view_as_real(u_f) + k_r = torch.view_as_real(k_f) + y_r = torch.view_as_real(y_f) + + DF = D * F + # BLOCK is supplied by @triton.autotune; the grid is a callable so it can + # read the chosen tile size from the winning config. + grid: Callable[[triton.Config], tuple[int, int]] = lambda meta: ( + triton.cdiv(DF, meta["BLOCK"]), + B, + ) + _hcm_complex_mul_kernel[grid]( + u_r, + k_r, + y_r, + DF, + 1.0 / fft_size, + u_r.stride(0), + ) + return y_f From b656d4248b3799230494070fb064bced14034677 Mon Sep 17 00:00:00 2001 From: alphakhaw Date: Tue, 19 May 2026 21:54:09 +0800 Subject: [PATCH 06/16] feat: add the HCM bias-residual Triton kernel --- test/test_hcm_kernel.py | 54 +++++++++++++-- vortex/ops/hcm_interface.py | 128 +++++++++++++++++++++++++++++++++--- 2 files changed, 168 insertions(+), 14 deletions(-) diff --git a/test/test_hcm_kernel.py b/test/test_hcm_kernel.py index 5e44e21..35cb7f0 100644 --- a/test/test_hcm_kernel.py +++ b/test/test_hcm_kernel.py @@ -1,17 +1,18 @@ """ -Correctness tests for the HCM fused complex-multiply Triton kernel. +Correctness tests for the HCM fused Triton kernels. -_hcm_complex_mul computes stage 3 of the HCM FFT-conv -- the elementwise -product of the activation and filter spectra -- with stage 1's 1/fft_size -filter normalisation folded in. The oracle is the explicit torch expression -fftconv_func evaluates: u_f * (rfft(k) / fft_size). +_hcm_complex_mul is stage 3 of the HCM FFT-conv -- the elementwise product +of the activation and filter spectra, with stage 1's 1/fft_size filter +normalisation folded in. _hcm_bias_residual is stage 5 -- the skip-residual +add y + u * bias[:, None], written out at u's dtype. Each kernel is checked +against the explicit torch expression fftconv_func evaluates. """ import pytest import torch from vortex.model.engine import adjust_filter_shape_for_broadcast -from vortex.ops.hcm_interface import _hcm_complex_mul +from vortex.ops.hcm_interface import _hcm_bias_residual, _hcm_complex_mul CUDA: bool = torch.cuda.is_available() @@ -67,3 +68,44 @@ def test_hcm_complex_mul_masks_ragged_tail() -> None: y_f = _hcm_complex_mul(u_f, k_f, fft_size) y_f_ref = u_f * (k_f / fft_size) assert (y_f - y_f_ref).abs().max().item() < 1e-4 + + +@pytest.mark.skipif(not CUDA, reason="HCM Triton kernel requires CUDA") +@pytest.mark.parametrize("B", [1, 2]) +@pytest.mark.parametrize("L", [1024, 8192, 32768]) +def test_hcm_bias_residual_matches_oracle(B: int, L: int) -> None: + """ + The fused bias-residual matches (y + u * bias[:, None]).to(u.dtype) at + evo2_7b shapes; B=2 exercises the per-channel bias broadcast over batch. + """ + torch.manual_seed(0) + D = 4096 + y = torch.randn(B, D, L, dtype=torch.float32, device="cuda") + u = torch.randn(B, D, L, dtype=torch.float32, device="cuda") + bias = torch.randn(D, dtype=torch.float32, device="cuda") + + out = _hcm_bias_residual(y, u, bias) + out_ref = (y + u * bias.unsqueeze(-1)).to(u.dtype) + + assert out.shape == y.shape + assert out.dtype == u.dtype + max_diff = (out - out_ref).abs().max().item() + mean_diff = (out - out_ref).abs().mean().item() + assert max_diff < 1e-4, f"max_diff={max_diff:.2e}" + assert mean_diff < 1e-5, f"mean_diff={mean_diff:.2e}" + + +@pytest.mark.skipif(not CUDA, reason="HCM Triton kernel requires CUDA") +def test_hcm_bias_residual_masks_ragged_tile() -> None: + """ + The kernel masks (D, L) tiles that BLOCK_D x BLOCK_L does not divide evenly. + """ + torch.manual_seed(0) + B, D, L = 1, 100, 300 + y = torch.randn(B, D, L, dtype=torch.float32, device="cuda") + u = torch.randn(B, D, L, dtype=torch.float32, device="cuda") + bias = torch.randn(D, dtype=torch.float32, device="cuda") + + out = _hcm_bias_residual(y, u, bias) + out_ref = (y + u * bias.unsqueeze(-1)).to(u.dtype) + assert (out - out_ref).abs().max().item() < 1e-4 diff --git a/vortex/ops/hcm_interface.py b/vortex/ops/hcm_interface.py index 44764d0..0e1d3b9 100644 --- a/vortex/ops/hcm_interface.py +++ b/vortex/ops/hcm_interface.py @@ -8,9 +8,9 @@ so the win is launch-count: the elementwise glue around the three cuFFT calls is fused into Triton kernels. -This module currently provides _hcm_complex_mul -- stage 3 of fftconv_func, -the broadcast complex product u_f * k_f, with stage 1's 1/fft_size filter -normalisation folded in. +It provides _hcm_complex_mul (stage 3 of fftconv_func -- the broadcast +complex product u_f * k_f, with stage 1's 1/fft_size scale folded in) and +_hcm_bias_residual (stage 5 -- the skip-residual add y + u * bias). """ from typing import Callable @@ -19,18 +19,27 @@ import triton import triton.language as tl -# Autotuned search space for the flat (D, F) tile. The op is memory-bound -# elementwise, so the winner is whichever BLOCK best saturates bandwidth for -# a given problem size -- Triton benchmarks these once per DF and caches it. -_AUTOTUNE_CONFIGS: list[triton.Config] = [ +# Autotuned search spaces -- both kernels are memory-bound elementwise work, +# so the winning tile is whichever best saturates bandwidth at a given shape. +# Triton benchmarks each set once per shape key and caches the winner. +_COMPLEX_MUL_CONFIGS: list[triton.Config] = [ triton.Config({"BLOCK": 256}, num_warps=2), triton.Config({"BLOCK": 512}, num_warps=4), triton.Config({"BLOCK": 1024}, num_warps=4), triton.Config({"BLOCK": 2048}, num_warps=8), ] +_BIAS_RESIDUAL_CONFIGS: list[triton.Config] = [ + triton.Config({"BLOCK_D": 32, "BLOCK_L": 64}, num_warps=2), + triton.Config({"BLOCK_D": 64, "BLOCK_L": 64}, num_warps=4), + triton.Config({"BLOCK_D": 64, "BLOCK_L": 128}, num_warps=4), + triton.Config({"BLOCK_D": 128, "BLOCK_L": 64}, num_warps=4), + triton.Config({"BLOCK_D": 64, "BLOCK_L": 256}, num_warps=8), + triton.Config({"BLOCK_D": 128, "BLOCK_L": 128}, num_warps=8), +] + -@triton.autotune(configs=_AUTOTUNE_CONFIGS, key=["DF"]) +@triton.autotune(configs=_COMPLEX_MUL_CONFIGS, key=["DF"]) @triton.jit def _hcm_complex_mul_kernel( u_ptr, @@ -138,3 +147,106 @@ def _hcm_complex_mul( u_r.stride(0), ) return y_f + + +@triton.autotune(configs=_BIAS_RESIDUAL_CONFIGS, key=["D", "L"]) +@triton.jit +def _hcm_bias_residual_kernel( + y_ptr, + u_ptr, + bias_ptr, + out_ptr, + D, + L, + stride_b, + stride_d, + stride_l, + BLOCK_D: tl.constexpr, + BLOCK_L: tl.constexpr, +): + """ + Skip-residual add: out[b, d, l] = y[b, d, l] + u[b, d, l] * bias[d]. + + One program covers a (BLOCK_D, BLOCK_L) tile of one batch element. y, u + and out share a contiguous (B, D, L) layout; bias is per-channel, shape + (D,), broadcast over batch and length. The fp32 accumulator is cast to + out's dtype on the store -- fftconv_func's stage-6 .to(u.dtype) cast. + """ + pid_b = tl.program_id(0) + pid_d = tl.program_id(1) + pid_l = tl.program_id(2) + + offs_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) + offs_l = pid_l * BLOCK_L + tl.arange(0, BLOCK_L) + mask_d = offs_d < D + mask_l = offs_l < L + mask = mask_d[:, None] & mask_l[None, :] + + offs = pid_b * stride_b + offs_d[:, None] * stride_d + offs_l[None, :] * stride_l + y = tl.load(y_ptr + offs, mask=mask, other=0.0).to(tl.float32) + u = tl.load(u_ptr + offs, mask=mask, other=0.0).to(tl.float32) + bias = tl.load(bias_ptr + offs_d, mask=mask_d, other=0.0).to(tl.float32) + + acc = y + u * bias[:, None] + tl.store(out_ptr + offs, acc, mask=mask) + + +def _hcm_bias_residual( + y: torch.Tensor, u: torch.Tensor, bias: torch.Tensor +) -> torch.Tensor: + """ + Fused skip-residual add -- stage 5 of fftconv_func. + + Computes y + u * bias[:, None] and writes it at u's dtype, fusing + fftconv_func's broadcast multiply and residual add (and its stage-6 + dtype cast) into a single Triton launch. + + Args: + y (torch.Tensor): The irfft output, shape (B, D, L). + u (torch.Tensor): The activations, shape (B, D, L); its dtype is the + output dtype -- fftconv_func's stage-6 cast target. + bias (torch.Tensor): Per-channel skip gain, shape (D,), broadcast + over batch and length. + + Returns: + torch.Tensor: y + u * bias[:, None], shape (B, D, L), u's dtype. + + Raises: + ValueError: If y and u are not matching 3-D tensors, or bias is not + 1-D of length D. + """ + if y.dim() != 3 or u.dim() != 3: + raise ValueError( + f"expected 3-D y and u, got {tuple(y.shape)} and {tuple(u.shape)}" + ) + if y.shape != u.shape: + raise ValueError(f"y {tuple(y.shape)} and u {tuple(u.shape)} must match") + + B, D, L = u.shape + if bias.dim() != 1 or bias.shape[0] != D: + raise ValueError(f"bias {tuple(bias.shape)} must be 1-D of length D={D}") + + y = y.contiguous() + u = u.contiguous() + bias = bias.contiguous() + out: torch.Tensor = torch.empty_like(u) + + # BLOCK_D / BLOCK_L are supplied by @triton.autotune; the grid is a + # callable so it can read the chosen tile sizes from the winning config. + grid: Callable[[triton.Config], tuple[int, int, int]] = lambda meta: ( + B, + triton.cdiv(D, meta["BLOCK_D"]), + triton.cdiv(L, meta["BLOCK_L"]), + ) + _hcm_bias_residual_kernel[grid]( + y, + u, + bias, + out, + D, + L, + u.stride(0), + u.stride(1), + u.stride(2), + ) + return out From e7a72c8c76f77355459255ef29d4c4b0996e6779 Mon Sep 17 00:00:00 2001 From: alphakhaw Date: Tue, 19 May 2026 22:47:08 +0800 Subject: [PATCH 07/16] feat: add hcm_fft_conv, the HCM FFT-conv wrapper --- test/test_hcm_adapter.py | 61 +++++++++++++++++++++++++++++++++++ vortex/ops/hcm_interface.py | 64 +++++++++++++++++++++++++++++++++++-- 2 files changed, 122 insertions(+), 3 deletions(-) create mode 100644 test/test_hcm_adapter.py diff --git a/test/test_hcm_adapter.py b/test/test_hcm_adapter.py new file mode 100644 index 0000000..3ca671e --- /dev/null +++ b/test/test_hcm_adapter.py @@ -0,0 +1,61 @@ +""" +Tests for hcm_fft_conv -- the drop-in for the fir_length >= 128 (HCM) branch +of HyenaInferenceEngine.parallel_fir. + +The oracle is upstream fftconv_func: with use_hcm_kernel off the engine calls +it directly, so hcm_fft_conv must reproduce it -- the flag is a no-op. +""" + +import pytest +import torch + +from vortex.model.engine import fftconv_func +from vortex.ops.hcm_interface import hcm_fft_conv + +CUDA: bool = torch.cuda.is_available() + + +def _hcm_inputs(B: int, L: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Build (u, weight, bias) for an HCM fftconv call at evo2_7b shapes. + """ + torch.manual_seed(0) + D, K = 4096, 128 + u = torch.randn(B, D, L, dtype=torch.float32, device="cuda") + weight = torch.randn(D, 1, K, dtype=torch.float32, device="cuda") + bias = torch.randn(D, dtype=torch.float32, device="cuda") + return u, weight, bias + + +@pytest.mark.skipif(not CUDA, reason="hcm_fft_conv requires CUDA") +@pytest.mark.parametrize("B", [1, 2]) +@pytest.mark.parametrize("L", [2048, 8192, 32768]) +def test_hcm_fft_conv_matches_fftconv_func(B: int, L: int) -> None: + """ + hcm_fft_conv reproduces fftconv_func on the non-bidirectional HCM path; + B=2 exercises the filter broadcast over the batch. + """ + u, weight, bias = _hcm_inputs(B, L) + + z = hcm_fft_conv(u, weight, bias, None, gelu=False, bidirectional=False) + z_ref = fftconv_func(u, weight, bias, None, gelu=False, bidirectional=False) + + assert z.shape == z_ref.shape == u.shape + assert z.dtype == z_ref.dtype + max_diff = (z - z_ref).abs().max().item() + mean_diff = (z - z_ref).abs().mean().item() + assert max_diff < 1e-2, f"max_diff={max_diff:.2e}" + assert mean_diff < 1e-3, f"mean_diff={mean_diff:.2e}" + + +@pytest.mark.skipif(not CUDA, reason="hcm_fft_conv requires CUDA") +def test_hcm_fft_conv_rejects_unsupported_paths() -> None: + """ + hcm_fft_conv raises on bidirectional or reverse-filter calls -- paths the + HCM dispatch never exercises. + """ + u, weight, bias = _hcm_inputs(1, 2048) + with pytest.raises(NotImplementedError): + hcm_fft_conv(u, weight, bias, None, bidirectional=True) + with pytest.raises(NotImplementedError): + hcm_fft_conv(u, weight, bias, None, k_rev=weight) diff --git a/vortex/ops/hcm_interface.py b/vortex/ops/hcm_interface.py index 0e1d3b9..bb5728c 100644 --- a/vortex/ops/hcm_interface.py +++ b/vortex/ops/hcm_interface.py @@ -8,9 +8,9 @@ so the win is launch-count: the elementwise glue around the three cuFFT calls is fused into Triton kernels. -It provides _hcm_complex_mul (stage 3 of fftconv_func -- the broadcast -complex product u_f * k_f, with stage 1's 1/fft_size scale folded in) and -_hcm_bias_residual (stage 5 -- the skip-residual add y + u * bias). +It provides hcm_fft_conv -- a drop-in for fftconv_func -- built on two fused +stage kernels: _hcm_complex_mul (stages 1 + 3, the scaled spectral product +u_f * k_f) and _hcm_bias_residual (stage 5, the skip-residual add y + u*bias). """ from typing import Callable @@ -250,3 +250,61 @@ def _hcm_bias_residual( u.stride(2), ) return out + + +def hcm_fft_conv( + u: torch.Tensor, + k: torch.Tensor, + D: torch.Tensor, + dropout_mask: torch.Tensor | None, + gelu: bool = True, + k_rev: torch.Tensor | None = None, + bidirectional: bool = False, + print_activations: bool = False, + layer_idx: int | None = None, + **kwargs, +) -> torch.Tensor: + """ + Fused HCM FFT-convolution -- a drop-in for fftconv_func. + + Reproduces fftconv_func's non-bidirectional inference path with the + elementwise glue fused into Triton kernels: cuFFT keeps the three + transforms, _hcm_complex_mul does the scaled spectral product (stages + 1 + 3) and _hcm_bias_residual the skip-residual add (stage 5). The + signature matches fftconv_func so the engine dispatch is a one-line swap. + + Args: + u (torch.Tensor): Input activations, shape (B, D, L). + k (torch.Tensor): Filter, shape (D, 1, K) (squeezable to (D, K)). + D (torch.Tensor): Per-channel skip-connection bias, shape (D,). + dropout_mask (torch.Tensor | None): Unused on the inference path; + kept for fftconv_func parity. + gelu (bool): Unused -- fftconv_func never applies it; kept for parity. + k_rev (torch.Tensor | None): Reverse filter; must be None -- the HCM + dispatch never sets it. + bidirectional (bool): Must be False -- the HCM branch is causal. + print_activations (bool): Accepted for parity; this path does not log. + layer_idx (int | None): Accepted for parity; unused. + + Returns: + torch.Tensor: y + u * D[:, None], shape (B, D, L), u's dtype -- + identical in shape, dtype and value to fftconv_func. + + Raises: + NotImplementedError: If bidirectional is True or k_rev is set; the + HCM dispatch never exercises those paths. + """ + if bidirectional or k_rev is not None: + raise NotImplementedError("hcm_fft_conv handles only the causal, non-reverse path") + + seqlen = u.shape[-1] + fft_size = 2 * seqlen + + # rfft(k) reshaped to (1, D, F) for the batch broadcast -- inlined, not + # adjust_filter_shape_for_broadcast, to avoid an engine import cycle. + k_f = torch.fft.rfft(k, n=fft_size).squeeze().unsqueeze(0) + u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) # stage 2 + + prod = _hcm_complex_mul(u_f, k_f, fft_size) # stages 1 + 3 + y = torch.fft.irfft(prod, n=fft_size, norm="forward")[..., :seqlen] # stage 4 + return _hcm_bias_residual(y, u, D) # stages 5 + 6 From 665ed46b3434e3f6d2a2fed1356beb8ac958629f Mon Sep 17 00:00:00 2001 From: alphakhaw Date: Tue, 19 May 2026 22:47:08 +0800 Subject: [PATCH 08/16] feat: dispatch HCM behind the use_hcm_kernel flag --- test/test_hcm_e2e.py | 117 ++++++++++++++++++++++++++++++++++++++++ test/test_hcm_engine.py | 112 ++++++++++++++++++++++++++++++++++++++ vortex/model/engine.py | 20 +++++-- vortex/model/model.py | 1 + 4 files changed, 246 insertions(+), 4 deletions(-) create mode 100644 test/test_hcm_e2e.py create mode 100644 test/test_hcm_engine.py diff --git a/test/test_hcm_e2e.py b/test/test_hcm_e2e.py new file mode 100644 index 0000000..bf30166 --- /dev/null +++ b/test/test_hcm_e2e.py @@ -0,0 +1,117 @@ +""" +End-to-end test for the use_hcm_kernel flag inside a real Evo2 model. + +Loads an Evo2 checkpoint once and runs the same forward twice -- with the +HCM kernel off (stock fftconv_func) and on (the fused Triton HCM FFT-conv). +The flag changes only the HCM conv implementation -- a fused kernel that is +numerically equivalent to fftconv_func but not bit-exact -- so the model's +token predictions must agree and the logit vectors stay near-parallel. + +The model id defaults to evo2_7b (the variant cached on the dev box) and is +overridable with VK_E2E_MODEL. The checkpoint is large; the test skips +cleanly if it cannot be loaded. +""" + +import os + +import pytest +import torch + +from vortex.model.engine import HyenaInferenceEngine + +CUDA: bool = torch.cuda.is_available() +_MODEL_ID: str = os.environ.get("VK_E2E_MODEL", "evo2_7b") +_SEQ_LEN: int = 2048 + + +@pytest.fixture(scope="module") +def evo2_model(): + """ + Load the Evo2 model once for the module, or skip if unavailable. + """ + if not CUDA: + pytest.skip("Evo2 e2e test requires CUDA") + try: + from evo2 import Evo2 + except ImportError as exc: # pragma: no cover - depends on optional dep + pytest.skip(f"evo2 not installed: {exc}") + try: + return Evo2(_MODEL_ID) + except Exception as exc: # noqa: BLE001 - any load failure -> skip, not fail + pytest.skip(f"could not load Evo2({_MODEL_ID!r}): {exc}") + + +def _set_hcm_kernel(model, enabled: bool) -> int: + """ + Flip use_hcm_kernel on every HyenaInferenceEngine reachable from model. + + The flag is read off the engine instance (normally set from the + use_hcm_kernel config key at build time); toggling it on the loaded + model lets one checkpoint serve both halves of the comparison. + + Args: + model: A loaded Evo2 model. + enabled (bool): Target value for use_hcm_kernel. + + Returns: + The number of HyenaInferenceEngine instances touched. + """ + root = getattr(model, "model", model) + touched = 0 + for module in root.modules(): + engine = getattr(module, "engine", None) + if isinstance(engine, HyenaInferenceEngine): + engine.use_hcm_kernel = enabled + touched += 1 + return touched + + +def _logits(model, input_ids: torch.Tensor) -> torch.Tensor: + """ + Run a forward pass and return the logits tensor as fp32. + + Evo2 wraps its output in nested tuples, so descend to the first tensor. + """ + with torch.no_grad(): + out = model(input_ids) + while isinstance(out, (tuple, list)): + out = out[0] + return out.float() + + +@pytest.mark.skipif(not CUDA, reason="Evo2 e2e test requires CUDA") +def test_vk_hcm_e2e_matches_baseline(evo2_model) -> None: + """ + A full Evo2 forward is behaviourally unchanged when use_hcm_kernel swaps + in the fused HCM kernel. + + hcm_fft_conv is numerically equivalent to fftconv_func (~2e-7 relative), + not bit-exact. One HCM layer carries ~1e5-magnitude activations, so that + relative precision becomes a sizeable absolute logit difference once the + downstream blocks amplify it -- an absolute logit tolerance is meaningless + here. The behavioural invariant is tested instead: the model predicts the + same tokens and the logit vectors stay near-parallel. + """ + torch.manual_seed(0) + input_ids = torch.randint(1, 5, (1, _SEQ_LEN), dtype=torch.int, device="cuda:0") + + try: + touched = _set_hcm_kernel(evo2_model, False) + assert touched > 0, "no HyenaInferenceEngine found in the Evo2 model" + logits_off = _logits(evo2_model, input_ids) + + _set_hcm_kernel(evo2_model, True) + logits_on = _logits(evo2_model, input_ids) + finally: + _set_hcm_kernel(evo2_model, False) + + assert logits_on.shape == logits_off.shape + + agreement = (logits_on.argmax(-1) == logits_off.argmax(-1)).float().mean().item() + cosine = torch.nn.functional.cosine_similarity( + logits_on.flatten(), logits_off.flatten(), dim=0 + ).item() + assert agreement >= 0.99, ( + f"use_hcm_kernel changed {(1 - agreement) * 100:.2f}% of token predictions" + ) + assert cosine >= 0.9999, f"use_hcm_kernel logits diverged: cosine={cosine:.6f}" diff --git a/test/test_hcm_engine.py b/test/test_hcm_engine.py new file mode 100644 index 0000000..7b83b8a --- /dev/null +++ b/test/test_hcm_engine.py @@ -0,0 +1,112 @@ +""" +Wiring tests for the use_hcm_kernel branch in HyenaInferenceEngine.parallel_fir. + +These exercise the fir_length >= 128 dispatch by calling parallel_fir directly +(no model weights needed). With the kernel enabled, the HCM cascade must +reproduce the stock fftconv_func output; the branch must not fire for the +HCS-length cascade. +""" + +import pytest +import torch +import torch.nn.functional as F + +from vortex.model.engine import HyenaInferenceEngine + +CUDA: bool = torch.cuda.is_available() + +# evo2_7b HCM cascade shapes: D=4096, fir_length=128. +B, D, K, GROUPS = 1, 4096, 128, 256 +DIMS: tuple[int, int, int, int, int] = (D, 32, D // 32, 16, GROUPS) +_CASCADE_KW: dict[str, bool | int | None] = dict( + groups=GROUPS, + gated_bias=False, + column_split_hyena=False, + dim_last=False, + fir_length=K, + gate=True, +) + + +def _hcm_inputs(L: int, dtype: torch.dtype): + """ + Build (u, weight, bias) for an HCM cascade parallel_fir call. + """ + u = torch.randn(B, 3 * D, L, dtype=dtype, device="cuda") + weight = torch.randn(D, 1, K, dtype=dtype, device="cuda") + bias = torch.randn(D, dtype=dtype, device="cuda") + return u, weight, bias + + +@pytest.mark.skipif(not CUDA, reason="HCM kernel requires CUDA") +@pytest.mark.parametrize("L", [2048, 8192]) +def test_vk_hcm_on_matches_baseline_fp32(L: int) -> None: + """ + use_hcm_kernel on reproduces the stock parallel_fir HCM output in fp32. + """ + torch.manual_seed(0) + engine = HyenaInferenceEngine(layer_idx=0) + u, weight, bias = _hcm_inputs(L, torch.float32) + + z_off, _ = engine.parallel_fir(F.conv1d, u, weight, bias, L, DIMS, **_CASCADE_KW) + + engine.use_hcm_kernel = True + z_on, _ = engine.parallel_fir(F.conv1d, u, weight, bias, L, DIMS, **_CASCADE_KW) + + assert z_on.shape == z_off.shape == (B, D, L) + assert (z_on - z_off).abs().max().item() < 1e-2 + + +@pytest.mark.skipif(not CUDA, reason="HCM kernel requires CUDA") +def test_vk_hcm_on_matches_baseline_bf16() -> None: + """ + use_hcm_kernel on reproduces the stock HCM output in bf16, the inference dtype. + """ + torch.manual_seed(0) + engine = HyenaInferenceEngine(layer_idx=0) + u, weight, bias = _hcm_inputs(2048, torch.bfloat16) + + z_off, _ = engine.parallel_fir(F.conv1d, u, weight, bias, 2048, DIMS, **_CASCADE_KW) + + engine.use_hcm_kernel = True + z_on, _ = engine.parallel_fir(F.conv1d, u, weight, bias, 2048, DIMS, **_CASCADE_KW) + + torch.testing.assert_close(z_on, z_off, rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not CUDA, reason="HCM kernel requires CUDA") +def test_vk_hcm_off_by_default() -> None: + """ + A fresh HyenaInferenceEngine has use_hcm_kernel False, so parallel_fir + takes the stock fftconv_func path with no behavioural change. + """ + torch.manual_seed(0) + engine = HyenaInferenceEngine(layer_idx=0) + assert engine.use_hcm_kernel is False + + u, weight, bias = _hcm_inputs(2048, torch.float32) + z_default, _ = engine.parallel_fir(F.conv1d, u, weight, bias, 2048, DIMS, **_CASCADE_KW) + + explicit_off = HyenaInferenceEngine(layer_idx=0, use_hcm_kernel=False) + z_explicit, _ = explicit_off.parallel_fir(F.conv1d, u, weight, bias, 2048, DIMS, **_CASCADE_KW) + + # both took the stock path, so they are bitwise identical + assert (z_default - z_explicit).abs().max().item() == 0.0 + + +@pytest.mark.skipif(not CUDA, reason="HCM kernel requires CUDA") +def test_vk_hcm_predicate_skips_hcs_calls() -> None: + """ + The branch matches only the fir_length >= 128 cascade -- an HCS-length + call (fir_length=7) still routes through the stock F.conv1d path. + """ + torch.manual_seed(0) + engine = HyenaInferenceEngine(layer_idx=0, use_hcm_kernel=True) + + u = torch.randn(B, 3 * D, 1024, dtype=torch.float32, device="cuda") + weight = torch.randn(D, 1, 7, dtype=torch.float32, device="cuda") + bias = torch.randn(D, dtype=torch.float32, device="cuda") + kw = dict(_CASCADE_KW) + kw["fir_length"] = 7 + z, _ = engine.parallel_fir(F.conv1d, u, weight, bias, 1024, DIMS, **kw) + assert z.shape == (B, D, 1024) diff --git a/vortex/model/engine.py b/vortex/model/engine.py index 3b5bca0..86ff43f 100644 --- a/vortex/model/engine.py +++ b/vortex/model/engine.py @@ -8,14 +8,17 @@ from vortex.model.utils import column_split from vortex.logging import activations_logger -# vortex-kernels: optional fused Triton HCS kernel (refs #16, #76). hcs_interface -# imports triton -- a Linux/GPU-only optional dependency -- so this import is -# guarded; `import vortex` must still succeed where triton is not installed. +# vortex-kernels: optional fused Triton kernels (refs #16, #76), guarded so `import vortex` works without triton. try: from vortex.ops.hcs_interface import hcs_conv except ImportError: hcs_conv = None +try: + from vortex.ops.hcm_interface import hcm_fft_conv +except ImportError: + hcm_fft_conv = None + IIR_PREFILL_MODES = [ "recurrence", "modal-fft", @@ -123,6 +126,7 @@ def __init__( print_activations=False, hyena_flip_x1x2=False, use_hcs_kernel=False, + use_hcm_kernel=False, ) -> None: self.fir_fn = fir_fn assert iir_prefill_style in IIR_PREFILL_MODES, f"iir_prefill_style must be one of {IIR_PREFILL_MODES}" @@ -133,6 +137,7 @@ def __init__( self.print_activations = print_activations self.hyena_flip_x1x2 = hyena_flip_x1x2 self.use_hcs_kernel = use_hcs_kernel + self.use_hcm_kernel = use_hcm_kernel def parallel_fir( self, @@ -186,8 +191,15 @@ def parallel_fir( z = fir_fn(u)[:, :L] # B, L, D elif fir_length >= 128: + # vortex-kernels: opt-in fused Triton HCM FFT-conv (refs #16, #76). + # Flag off -> fftconv_func -> byte-identical to stock vortex. + fftconv = ( + hcm_fft_conv + if self.use_hcm_kernel and hcm_fft_conv is not None + else fftconv_func + ) with torch.autocast("cuda"): - z = fftconv_func( + z = fftconv( u.to(torch.float32), weight[:, :, :L].to(torch.float32), bias, diff --git a/vortex/model/model.py b/vortex/model/model.py index dba0d66..8388162 100644 --- a/vortex/model/model.py +++ b/vortex/model/model.py @@ -158,6 +158,7 @@ def __init__(self, config, layer_idx, hyena_filter_groups=None, fir_inner_filter print_activations=self.print_activations, hyena_flip_x1x2=config.get("hyena_flip_x1x2", False), use_hcs_kernel=config.get("use_hcs_kernel", False), + use_hcm_kernel=config.get("use_hcm_kernel", False), ) self.use_flash_depthwise = config.get("use_flash_depthwise", False) self.data_dtype = None From 3ce606574df08163e441123df2ad9812f5bf8be1 Mon Sep 17 00:00:00 2001 From: alphakhaw Date: Wed, 20 May 2026 00:11:12 +0800 Subject: [PATCH 09/16] feat: add the tiled HCL compute_filter kernel --- test/test_hcl_kernel.py | 100 +++++++++++++++++++++++++++ vortex/ops/hcl_interface.py | 130 ++++++++++++++++++++++++++++++++++++ 2 files changed, 230 insertions(+) create mode 100644 test/test_hcl_kernel.py diff --git a/test/test_hcl_kernel.py b/test/test_hcl_kernel.py new file mode 100644 index 0000000..ee30ffa --- /dev/null +++ b/test/test_hcl_kernel.py @@ -0,0 +1,100 @@ +""" +Correctness and peak-memory tests for the HCL tiled compute_filter kernel. + +_hcl_compute_filter builds the modal filter h[d, l] = sum_s residues[d, s] * +exp(log_poles[d, s] * t[l]) without the (D, state_size, L) intermediate that +OOMs evo2_7b at L=131k. Correctness is checked against the explicit torch +reduction compute_filter (model.py) runs; a peak-allocation test confirms the +intermediate is never built. +""" + +import pytest +import torch + +from vortex.ops.hcl_interface import _hcl_compute_filter + +CUDA: bool = torch.cuda.is_available() + + +def _modal_filter_ref( + residues: torch.Tensor, log_poles: torch.Tensor, t: torch.Tensor +) -> torch.Tensor: + """ + Reference modal filter -- compute_filter's (D, S, L) reduction, (D, L) out. + """ + return (residues[:, :, None] * (log_poles[:, :, None] * t).exp()).sum(1) + + +def _hcl_inputs( + D: int, L: int, S: int = 16 +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Build (residues, log_poles, t) for a stable modal filter. + + log_poles are small and negative so exp(log_poles * t) decays over the + sequence without overflowing -- a stable SSM, as in a trained evo2. + """ + torch.manual_seed(0) + residues = torch.randn(D, S, dtype=torch.float32, device="cuda") + log_poles = -torch.rand(D, S, dtype=torch.float32, device="cuda") * (8.0 / L) + t = torch.arange(L, dtype=torch.float32, device="cuda") + return residues, log_poles, t + + +@pytest.mark.skipif(not CUDA, reason="HCL Triton kernel requires CUDA") +@pytest.mark.parametrize("L", [2048, 8192, 32768]) +def test_hcl_compute_filter_matches_oracle(L: int) -> None: + """ + The tiled compute_filter kernel matches the torch modal-filter reduction + at evo2_7b shapes (D=4096, state_size=16). + """ + residues, log_poles, t = _hcl_inputs(D=4096, L=L) + + h = _hcl_compute_filter(residues, log_poles, t) + h_ref = _modal_filter_ref(residues, log_poles, t) + + assert h.shape == h_ref.shape == (4096, L) + assert h.dtype == torch.float32 + max_diff = (h - h_ref).abs().max().item() + mean_diff = (h - h_ref).abs().mean().item() + assert max_diff < 1e-3, f"max_diff={max_diff:.2e}" + assert mean_diff < 1e-4, f"mean_diff={mean_diff:.2e}" + + +@pytest.mark.skipif(not CUDA, reason="HCL Triton kernel requires CUDA") +def test_hcl_compute_filter_masks_ragged_tile() -> None: + """ + The kernel masks (D, L) tiles that BLOCK_D x BLOCK_L does not divide evenly. + """ + residues, log_poles, t = _hcl_inputs(D=100, L=300) + + h = _hcl_compute_filter(residues, log_poles, t) + h_ref = _modal_filter_ref(residues, log_poles, t) + assert (h - h_ref).abs().max().item() < 1e-3 + + +@pytest.mark.skipif(not CUDA, reason="HCL Triton kernel requires CUDA") +def test_hcl_compute_filter_avoids_the_intermediate() -> None: + """ + The kernel's peak allocation stays well below the reference path, which + materialises the (D, state_size, L) intermediate -- direct evidence the + tiled reduction never builds it. + """ + residues, log_poles, t = _hcl_inputs(D=4096, L=32768) + + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + _modal_filter_ref(residues, log_poles, t) + torch.cuda.synchronize() + ref_peak = torch.cuda.max_memory_allocated() + + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + _hcl_compute_filter(residues, log_poles, t) + torch.cuda.synchronize() + kernel_peak = torch.cuda.max_memory_allocated() + + assert kernel_peak < ref_peak * 0.5, ( + f"kernel peak {kernel_peak / 1e9:.2f} GB is not below half the " + f"reference peak {ref_peak / 1e9:.2f} GB" + ) diff --git a/vortex/ops/hcl_interface.py b/vortex/ops/hcl_interface.py index e69de29..ac98fe4 100644 --- a/vortex/ops/hcl_interface.py +++ b/vortex/ops/hcl_interface.py @@ -0,0 +1,130 @@ +# pyright: reportAttributeAccessIssue=none +""" +HCL -- Hyena Cascade Long. + +Triton kernels for the long-filter (long_fir_threshold is None) FFT-conv path +of HyenaInferenceEngine.parallel_iir. HCL is the memory-unlock kernel: the +stock compute_filter materialises a (D, state_size, L) fp32 intermediate that +OOMs evo2_7b at L=131k. + +This module currently provides _hcl_compute_filter -- the tiled modal-filter +build that does the state-size reduction in-register, so that intermediate +never exists. +""" + +from typing import Callable + +import torch +import triton +import triton.language as tl + +# Autotuned search space for the (D, L) tile. Pure elementwise + a 16-term +# in-register reduction, so the winner is whichever tile best saturates +# bandwidth -- Triton benchmarks these once per (D, L) and caches the winner. +_COMPUTE_FILTER_CONFIGS: list[triton.Config] = [ + triton.Config({"BLOCK_D": 32, "BLOCK_L": 64}, num_warps=2), + triton.Config({"BLOCK_D": 64, "BLOCK_L": 64}, num_warps=4), + triton.Config({"BLOCK_D": 64, "BLOCK_L": 128}, num_warps=4), + triton.Config({"BLOCK_D": 128, "BLOCK_L": 64}, num_warps=4), + triton.Config({"BLOCK_D": 64, "BLOCK_L": 256}, num_warps=8), + triton.Config({"BLOCK_D": 128, "BLOCK_L": 128}, num_warps=8), +] + + +@triton.autotune(configs=_COMPUTE_FILTER_CONFIGS, key=["D", "L"]) +@triton.jit +def _hcl_compute_filter_kernel( + residues_ptr, + log_poles_ptr, + t_ptr, + h_ptr, + D, + L, + S: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_L: tl.constexpr, +): + """ + Modal filter: h[d, l] = sum_s residues[d, s] * exp(log_poles[d, s] * t[l]). + + One program covers a (BLOCK_D, BLOCK_L) tile of h. The state-size sum (S + terms) runs in the fp32 register accumulator, so the (D, S, L) intermediate + that OOMs the stock compute_filter at L=131k never exists. residues and + log_poles are (D, S) row-major; t is (L,); h is (D, L) row-major. + """ + pid_d = tl.program_id(0) + pid_l = tl.program_id(1) + + offs_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) + offs_l = pid_l * BLOCK_L + tl.arange(0, BLOCK_L) + mask_d = offs_d < D + mask_l = offs_l < L + + t_tile = tl.load(t_ptr + offs_l, mask=mask_l, other=0.0).to(tl.float32) + + acc = tl.zeros((BLOCK_D, BLOCK_L), dtype=tl.float32) + for s in tl.static_range(S): + r_s = tl.load(residues_ptr + offs_d * S + s, mask=mask_d, other=0.0).to( + tl.float32 + ) + lp_s = tl.load(log_poles_ptr + offs_d * S + s, mask=mask_d, other=0.0).to( + tl.float32 + ) + acc += r_s[:, None] * tl.exp(lp_s[:, None] * t_tile[None, :]) + + h_ptrs = h_ptr + offs_d[:, None] * L + offs_l[None, :] + tl.store(h_ptrs, acc, mask=mask_d[:, None] & mask_l[None, :]) + + +def _hcl_compute_filter( + residues: torch.Tensor, log_poles: torch.Tensor, t: torch.Tensor +) -> torch.Tensor: + """ + Tiled modal-filter build -- the HCL compute_filter without the OOM. + + Computes h[d, l] = sum_s residues[d, s] * exp(log_poles[d, s] * t[l]), the + (D, L) filter compute_filter builds, with the state-size sum done + in-register so the (D, state_size, L) intermediate never exists. + + Args: + residues (torch.Tensor): Modal residues, shape (D, S). + log_poles (torch.Tensor): Modal log-poles, shape (D, S); negative for + a stable (decaying) filter. + t (torch.Tensor): Time index [0, 1, ..., L-1], shape (L,). + + Returns: + torch.Tensor: The modal filter h, shape (D, L), fp32. + + Raises: + ValueError: If residues and log_poles are not matching 2-D tensors, + or t is not 1-D. + """ + if residues.dim() != 2 or log_poles.dim() != 2: + raise ValueError( + f"expected 2-D residues and log_poles, got {tuple(residues.shape)} " + f"and {tuple(log_poles.shape)}" + ) + if residues.shape != log_poles.shape: + raise ValueError( + f"residues {tuple(residues.shape)} and log_poles " + f"{tuple(log_poles.shape)} must match" + ) + if t.dim() != 1: + raise ValueError(f"expected 1-D t, got {tuple(t.shape)}") + + D, S = residues.shape + L: int = t.shape[0] + + residues = residues.contiguous().float() + log_poles = log_poles.contiguous().float() + t = t.contiguous().float() + h: torch.Tensor = torch.empty(D, L, dtype=torch.float32, device=residues.device) + + # BLOCK_D / BLOCK_L are supplied by @triton.autotune; the grid is a + # callable so it can read the chosen tile sizes from the winning config. + grid: Callable[[triton.Config], tuple[int, int]] = lambda meta: ( + triton.cdiv(D, meta["BLOCK_D"]), + triton.cdiv(L, meta["BLOCK_L"]), + ) + _hcl_compute_filter_kernel[grid](residues, log_poles, t, h, D, L, S) + return h From 6f11b62aa6f8534e1de7790d2962d15ad28a25ad Mon Sep 17 00:00:00 2001 From: alphakhaw Date: Wed, 20 May 2026 00:22:49 +0800 Subject: [PATCH 10/16] feat: add the HCL bias-residual-gate kernel --- test/test_hcl_kernel.py | 54 +++++++++++++-- vortex/ops/hcl_interface.py | 130 +++++++++++++++++++++++++++++++++--- 2 files changed, 171 insertions(+), 13 deletions(-) diff --git a/test/test_hcl_kernel.py b/test/test_hcl_kernel.py index ee30ffa..791f409 100644 --- a/test/test_hcl_kernel.py +++ b/test/test_hcl_kernel.py @@ -1,17 +1,18 @@ """ -Correctness and peak-memory tests for the HCL tiled compute_filter kernel. +Correctness and peak-memory tests for the HCL Triton kernels. _hcl_compute_filter builds the modal filter h[d, l] = sum_s residues[d, s] * exp(log_poles[d, s] * t[l]) without the (D, state_size, L) intermediate that -OOMs evo2_7b at L=131k. Correctness is checked against the explicit torch -reduction compute_filter (model.py) runs; a peak-allocation test confirms the -intermediate is never built. +OOMs evo2_7b at L=131k; a peak-allocation test confirms that intermediate is +never built. _hcl_bias_residual_gate fuses the FFT-conv epilogue +(y + x1v * bias[:, None]) * x2. Each kernel is checked against the explicit +torch expression parallel_iir runs. """ import pytest import torch -from vortex.ops.hcl_interface import _hcl_compute_filter +from vortex.ops.hcl_interface import _hcl_bias_residual_gate, _hcl_compute_filter CUDA: bool = torch.cuda.is_available() @@ -98,3 +99,46 @@ def test_hcl_compute_filter_avoids_the_intermediate() -> None: f"kernel peak {kernel_peak / 1e9:.2f} GB is not below half the " f"reference peak {ref_peak / 1e9:.2f} GB" ) + + +@pytest.mark.skipif(not CUDA, reason="HCL Triton kernel requires CUDA") +@pytest.mark.parametrize("B", [1, 2]) +@pytest.mark.parametrize("L", [2048, 8192, 32768]) +def test_hcl_bias_residual_gate_matches_oracle(B: int, L: int) -> None: + """ + The fused bias-residual-gate matches (y + x1v * bias[:, None]) * x2 at + evo2_7b shapes. + """ + torch.manual_seed(0) + D = 4096 + y = torch.randn(B, D, L, dtype=torch.float32, device="cuda") + x1v = torch.randn(B, D, L, dtype=torch.float32, device="cuda") + bias = torch.randn(D, dtype=torch.float32, device="cuda") + x2 = torch.randn(B, D, L, dtype=torch.float32, device="cuda") + + out = _hcl_bias_residual_gate(y, x1v, bias, x2) + out_ref = (y + x1v * bias.unsqueeze(-1)) * x2 + + assert out.shape == y.shape + assert out.dtype == x1v.dtype + max_diff = (out - out_ref).abs().max().item() + mean_diff = (out - out_ref).abs().mean().item() + assert max_diff < 1e-4, f"max_diff={max_diff:.2e}" + assert mean_diff < 1e-5, f"mean_diff={mean_diff:.2e}" + + +@pytest.mark.skipif(not CUDA, reason="HCL Triton kernel requires CUDA") +def test_hcl_bias_residual_gate_masks_ragged_tile() -> None: + """ + The kernel masks (D, L) tiles that BLOCK_D x BLOCK_L does not divide evenly. + """ + torch.manual_seed(0) + B, D, L = 1, 100, 300 + y = torch.randn(B, D, L, dtype=torch.float32, device="cuda") + x1v = torch.randn(B, D, L, dtype=torch.float32, device="cuda") + bias = torch.randn(D, dtype=torch.float32, device="cuda") + x2 = torch.randn(B, D, L, dtype=torch.float32, device="cuda") + + out = _hcl_bias_residual_gate(y, x1v, bias, x2) + out_ref = (y + x1v * bias.unsqueeze(-1)) * x2 + assert (out - out_ref).abs().max().item() < 1e-4 diff --git a/vortex/ops/hcl_interface.py b/vortex/ops/hcl_interface.py index ac98fe4..1e760db 100644 --- a/vortex/ops/hcl_interface.py +++ b/vortex/ops/hcl_interface.py @@ -7,9 +7,9 @@ stock compute_filter materialises a (D, state_size, L) fp32 intermediate that OOMs evo2_7b at L=131k. -This module currently provides _hcl_compute_filter -- the tiled modal-filter -build that does the state-size reduction in-register, so that intermediate -never exists. +It provides _hcl_compute_filter -- the tiled modal-filter build that does the +state-size reduction in-register so that intermediate never exists -- and +_hcl_bias_residual_gate, the fused FFT-conv epilogue (y + x1v * bias) * x2. """ from typing import Callable @@ -18,10 +18,10 @@ import triton import triton.language as tl -# Autotuned search space for the (D, L) tile. Pure elementwise + a 16-term -# in-register reduction, so the winner is whichever tile best saturates -# bandwidth -- Triton benchmarks these once per (D, L) and caches the winner. -_COMPUTE_FILTER_CONFIGS: list[triton.Config] = [ +# Autotuned 2-D BLOCK_D x BLOCK_L tile, shared by both kernels -- both are +# memory-bound elementwise work over a (D, L) grid, so the winner is whichever +# tile best saturates bandwidth, benchmarked once per (D, L) and cached. +_TILE_CONFIGS: list[triton.Config] = [ triton.Config({"BLOCK_D": 32, "BLOCK_L": 64}, num_warps=2), triton.Config({"BLOCK_D": 64, "BLOCK_L": 64}, num_warps=4), triton.Config({"BLOCK_D": 64, "BLOCK_L": 128}, num_warps=4), @@ -31,7 +31,7 @@ ] -@triton.autotune(configs=_COMPUTE_FILTER_CONFIGS, key=["D", "L"]) +@triton.autotune(configs=_TILE_CONFIGS, key=["D", "L"]) @triton.jit def _hcl_compute_filter_kernel( residues_ptr, @@ -128,3 +128,117 @@ def _hcl_compute_filter( ) _hcl_compute_filter_kernel[grid](residues, log_poles, t, h, D, L, S) return h + + +@triton.autotune(configs=_TILE_CONFIGS, key=["D", "L"]) +@triton.jit +def _hcl_bias_residual_gate_kernel( + y_ptr, + x1v_ptr, + bias_ptr, + x2_ptr, + out_ptr, + D, + L, + stride_b, + stride_d, + stride_l, + BLOCK_D: tl.constexpr, + BLOCK_L: tl.constexpr, +): + """ + Bias-residual + gate: out[b,d,l] = (y[b,d,l] + x1v[b,d,l]*bias[d]) * x2[b,d,l]. + + One program covers a (BLOCK_D, BLOCK_L) tile of one batch element. y, x1v, + x2 and out share a contiguous (B, D, L) layout; bias is per-channel, shape + (D,), broadcast over batch and length. The fp32 accumulator is cast to + out's dtype on the store -- parallel_iir's y.to(x1v.dtype) cast. + """ + pid_b = tl.program_id(0) + pid_d = tl.program_id(1) + pid_l = tl.program_id(2) + + offs_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) + offs_l = pid_l * BLOCK_L + tl.arange(0, BLOCK_L) + mask_d = offs_d < D + mask_l = offs_l < L + mask = mask_d[:, None] & mask_l[None, :] + + offs = pid_b * stride_b + offs_d[:, None] * stride_d + offs_l[None, :] * stride_l + y = tl.load(y_ptr + offs, mask=mask, other=0.0).to(tl.float32) + x1v = tl.load(x1v_ptr + offs, mask=mask, other=0.0).to(tl.float32) + x2 = tl.load(x2_ptr + offs, mask=mask, other=0.0).to(tl.float32) + bias = tl.load(bias_ptr + offs_d, mask=mask_d, other=0.0).to(tl.float32) + + acc = (y + x1v * bias[:, None]) * x2 + tl.store(out_ptr + offs, acc, mask=mask) + + +def _hcl_bias_residual_gate( + y: torch.Tensor, x1v: torch.Tensor, bias: torch.Tensor, x2: torch.Tensor +) -> torch.Tensor: + """ + Fused bias-residual + gate -- the HCL FFT-conv epilogue. + + Computes (y + x1v * bias[:, None]) * x2 and writes it at x1v's dtype -- + parallel_iir's post-conv `y = (y + x1v * D.unsqueeze(-1)) * x2`, fusing the + broadcast multiply, the residual add, the gate and the dtype cast into one + Triton launch. + + Args: + y (torch.Tensor): The irfft output, shape (B, D, L). + x1v (torch.Tensor): The conv input, shape (B, D, L); its dtype is the + output dtype. + bias (torch.Tensor): Per-channel skip gain, shape (D,), broadcast over + batch and length. + x2 (torch.Tensor): The post-gate stream, shape (B, D, L). + + Returns: + torch.Tensor: (y + x1v * bias[:, None]) * x2, shape (B, D, L), x1v's + dtype. + + Raises: + ValueError: If y, x1v and x2 are not matching 3-D tensors, or bias is + not 1-D of length D. + """ + if y.dim() != 3 or x1v.dim() != 3 or x2.dim() != 3: + raise ValueError( + f"expected 3-D y, x1v, x2, got {tuple(y.shape)}, " + f"{tuple(x1v.shape)}, {tuple(x2.shape)}" + ) + if not (y.shape == x1v.shape == x2.shape): + raise ValueError( + f"y {tuple(y.shape)}, x1v {tuple(x1v.shape)}, x2 {tuple(x2.shape)} " + f"must all match" + ) + + B, D, L = x1v.shape + if bias.dim() != 1 or bias.shape[0] != D: + raise ValueError(f"bias {tuple(bias.shape)} must be 1-D of length D={D}") + + y = y.contiguous() + x1v = x1v.contiguous() + x2 = x2.contiguous() + bias = bias.contiguous() + out: torch.Tensor = torch.empty_like(x1v) + + # BLOCK_D / BLOCK_L are supplied by @triton.autotune; the grid is a + # callable so it can read the chosen tile sizes from the winning config. + grid: Callable[[triton.Config], tuple[int, int, int]] = lambda meta: ( + B, + triton.cdiv(D, meta["BLOCK_D"]), + triton.cdiv(L, meta["BLOCK_L"]), + ) + _hcl_bias_residual_gate_kernel[grid]( + y, + x1v, + bias, + x2, + out, + D, + L, + x1v.stride(0), + x1v.stride(1), + x1v.stride(2), + ) + return out From 67365ef336a7d46fc64e8746cf9404c0b4440a44 Mon Sep 17 00:00:00 2001 From: alphakhaw Date: Wed, 20 May 2026 00:34:33 +0800 Subject: [PATCH 11/16] feat: add hcl_fft_conv, the HCL FFT-conv epilogue wrapper --- test/test_hcl_adapter.py | 58 +++++++++++++++++++++++++++++++++++++ vortex/ops/hcl_interface.py | 46 +++++++++++++++++++++++++++-- 2 files changed, 101 insertions(+), 3 deletions(-) create mode 100644 test/test_hcl_adapter.py diff --git a/test/test_hcl_adapter.py b/test/test_hcl_adapter.py new file mode 100644 index 0000000..91b1c21 --- /dev/null +++ b/test/test_hcl_adapter.py @@ -0,0 +1,58 @@ +""" +Tests for hcl_fft_conv -- the fused FFT-conv epilogue for parallel_iir's +long_fir_threshold-is-None (HCL) branch. + +The oracle is the stock branch: rfft(h)/fft_size, fft(x1v), X*H, irfft, then +the post-conv (y + x1v*D[:, None]) * x2. hcl_fft_conv reproduces it with the +_hcm_complex_mul and _hcl_bias_residual_gate kernels. +""" + +import pytest +import torch + +from vortex.ops.hcl_interface import hcl_fft_conv + +CUDA: bool = torch.cuda.is_available() + + +def _hcl_branch_ref( + h: torch.Tensor, + x1v: torch.Tensor, + x2: torch.Tensor, + D: torch.Tensor, + L: int, + fft_size: int, +) -> torch.Tensor: + """ + Pure-torch reference for parallel_iir's HCL FFT-conv branch + post-conv. + """ + H = torch.fft.rfft(h.to(torch.float32), n=fft_size) / fft_size + X = torch.fft.fft(x1v.to(torch.float32), n=fft_size)[..., : H.shape[-1]] + y = torch.fft.irfft(X * H, n=fft_size, norm="forward")[..., :L] + y = y.to(x1v.dtype) + return (y + x1v * D.unsqueeze(-1)) * x2 + + +@pytest.mark.skipif(not CUDA, reason="hcl_fft_conv requires CUDA") +@pytest.mark.parametrize("L", [2048, 8192, 32768]) +def test_hcl_fft_conv_matches_branch(L: int) -> None: + """ + hcl_fft_conv reproduces the stock parallel_iir HCL branch in fp32. + """ + torch.manual_seed(0) + D = 4096 + fft_size = 2 * L + h = torch.randn(1, D, L, dtype=torch.float32, device="cuda") + x1v = torch.randn(1, D, L, dtype=torch.float32, device="cuda") + x2 = torch.randn(1, D, L, dtype=torch.float32, device="cuda") + bias = torch.randn(D, dtype=torch.float32, device="cuda") + + y = hcl_fft_conv(h, x1v, x2, bias, L, fft_size) + y_ref = _hcl_branch_ref(h, x1v, x2, bias, L, fft_size) + + assert y.shape == y_ref.shape == (1, D, L) + assert y.dtype == y_ref.dtype + max_diff = (y - y_ref).abs().max().item() + mean_diff = (y - y_ref).abs().mean().item() + assert max_diff < 1e-2, f"max_diff={max_diff:.2e}" + assert mean_diff < 1e-3, f"mean_diff={mean_diff:.2e}" diff --git a/vortex/ops/hcl_interface.py b/vortex/ops/hcl_interface.py index 1e760db..8c29d20 100644 --- a/vortex/ops/hcl_interface.py +++ b/vortex/ops/hcl_interface.py @@ -7,9 +7,10 @@ stock compute_filter materialises a (D, state_size, L) fp32 intermediate that OOMs evo2_7b at L=131k. -It provides _hcl_compute_filter -- the tiled modal-filter build that does the -state-size reduction in-register so that intermediate never exists -- and -_hcl_bias_residual_gate, the fused FFT-conv epilogue (y + x1v * bias) * x2. +It provides hcl_fft_conv -- the fused FFT-conv epilogue for parallel_iir -- +built on _hcl_bias_residual_gate (the post-conv kernel) and HCM's complex +multiply; and _hcl_compute_filter, the tiled modal-filter build the model's +compute_filter swaps in so the (D, state_size, L) intermediate never exists. """ from typing import Callable @@ -18,6 +19,8 @@ import triton import triton.language as tl +from vortex.ops.hcm_interface import _hcm_complex_mul + # Autotuned 2-D BLOCK_D x BLOCK_L tile, shared by both kernels -- both are # memory-bound elementwise work over a (D, L) grid, so the winner is whichever # tile best saturates bandwidth, benchmarked once per (D, L) and cached. @@ -242,3 +245,40 @@ def _hcl_bias_residual_gate( x1v.stride(2), ) return out + + +def hcl_fft_conv( + h: torch.Tensor, + x1v: torch.Tensor, + x2: torch.Tensor, + D: torch.Tensor, + L: int, + fft_size: int, +) -> torch.Tensor: + """ + Fused HCL FFT-convolution epilogue. + + Reproduces parallel_iir's long_fir_threshold-is-None branch and its + post-conv in one path: cuFFT keeps the three transforms, _hcm_complex_mul + does the scaled spectral product X*H (stage 3, with stage 1's /fft_size + folded in), and _hcl_bias_residual_gate does the post-conv + (y + x1v*D[:, None]) * x2. Returns y fully gated and residualed. + + Args: + h (torch.Tensor): The modal filter, shape (1, D, L). + x1v (torch.Tensor): The conv input, shape (1, D, L). + x2 (torch.Tensor): The post-gate stream, shape (1, D, L). + D (torch.Tensor): Per-channel skip-connection bias, shape (D,). + L (int): Sequence length. + fft_size (int): The FFT length, n = 2 * L. + + Returns: + torch.Tensor: (y + x1v*D[:, None]) * x2, shape (1, D, L), x1v's dtype. + """ + H = torch.fft.rfft(h.to(dtype=torch.float32), n=fft_size) + X_s = torch.fft.fft(x1v.to(dtype=torch.float32), n=fft_size) + X = X_s[..., : H.shape[-1]] + + prod = _hcm_complex_mul(X, H, fft_size) + y = torch.fft.irfft(prod, n=fft_size, norm="forward")[..., :L] + return _hcl_bias_residual_gate(y, x1v, D, x2) From e0b820c15089dbe7414f796b283f98a5db74c0f0 Mon Sep 17 00:00:00 2001 From: alphakhaw Date: Wed, 20 May 2026 00:47:14 +0800 Subject: [PATCH 12/16] feat: dispatch HCL behind the use_hcl_kernel flag --- test/test_hcl_engine.py | 100 ++++++++++++++++++++++++++++++++++++++++ vortex/model/engine.py | 25 +++++++--- vortex/model/model.py | 18 +++++++- 3 files changed, 136 insertions(+), 7 deletions(-) create mode 100644 test/test_hcl_engine.py diff --git a/test/test_hcl_engine.py b/test/test_hcl_engine.py new file mode 100644 index 0000000..595a5f0 --- /dev/null +++ b/test/test_hcl_engine.py @@ -0,0 +1,100 @@ +""" +Wiring tests for the use_hcl_kernel branch in HyenaInferenceEngine.parallel_iir. + +These exercise the long_fir_threshold-is-None dispatch by calling parallel_iir +directly. With the kernel enabled, the HCL FFT-conv must reproduce the stock +branch; the branch must not fire when long_fir_threshold is set. +""" + +import pytest +import torch + +from vortex.model.engine import HyenaInferenceEngine + +CUDA: bool = torch.cuda.is_available() + +# evo2_7b HCL shapes: D=4096, state_size=16. +B, D, S = 1, 4096, 16 +DIMS: tuple[int, int, int, int, int] = (D, 32, D // 32, 16, 256) + + +def _hcl_inputs(L: int, dtype: torch.dtype): + """ + Build (z_pre, h, bias, poles, residues, t) for an HCL parallel_iir call. + """ + z_pre = torch.randn(B, 3 * D, L, dtype=dtype, device="cuda") + h = torch.randn(1, D, L, dtype=torch.float32, device="cuda") + bias = torch.randn(D, dtype=dtype, device="cuda") + poles = torch.randn(D, S, 1, dtype=torch.float32, device="cuda") + residues = torch.randn(D, S, dtype=torch.float32, device="cuda") + t = torch.arange(L, device="cuda") + return z_pre, h, bias, poles, residues, t + + +def _call(engine, z_pre, h, bias, L, poles, residues, t, **kw): + """ + Invoke parallel_iir on the HCL FFT path (long_fir_threshold None). + """ + return engine.parallel_iir( + z_pre, + h, + bias, + L, + poles=poles, + residues=residues, + t=t, + dims=DIMS, + layer_idx=0, + long_fir_threshold=kw.pop("long_fir_threshold", None), + **kw, + ) + + +@pytest.mark.skipif(not CUDA, reason="HCL kernel requires CUDA") +@pytest.mark.parametrize("L", [2048, 8192]) +def test_vk_hcl_on_matches_baseline_fp32(L: int) -> None: + """ + use_hcl_kernel on reproduces the stock parallel_iir HCL output in fp32. + """ + torch.manual_seed(0) + engine = HyenaInferenceEngine(layer_idx=0) + z_pre, h, bias, poles, residues, t = _hcl_inputs(L, torch.float32) + + y_off = _call(engine, z_pre, h, bias, L, poles, residues, t) + engine.use_hcl_kernel = True + y_on = _call(engine, z_pre, h, bias, L, poles, residues, t) + + assert y_on.shape == y_off.shape == (B, L, D) + assert (y_on - y_off).abs().max().item() < 1e-2 + + +@pytest.mark.skipif(not CUDA, reason="HCL kernel requires CUDA") +def test_vk_hcl_off_by_default() -> None: + """ + A fresh HyenaInferenceEngine has use_hcl_kernel False -- the stock path. + """ + torch.manual_seed(0) + engine = HyenaInferenceEngine(layer_idx=0) + assert engine.use_hcl_kernel is False + + z_pre, h, bias, poles, residues, t = _hcl_inputs(2048, torch.float32) + y_default = _call(engine, z_pre, h, bias, 2048, poles, residues, t) + + explicit_off = HyenaInferenceEngine(layer_idx=0, use_hcl_kernel=False) + y_explicit = _call(explicit_off, z_pre, h, bias, 2048, poles, residues, t) + + assert (y_default - y_explicit).abs().max().item() == 0.0 + + +@pytest.mark.skipif(not CUDA, reason="HCL kernel requires CUDA") +def test_vk_hcl_predicate_skips_long_fir() -> None: + """ + The branch matches only long_fir_threshold is None -- a set threshold + routes through the stock depthwise-conv path. + """ + torch.manual_seed(0) + engine = HyenaInferenceEngine(layer_idx=0, use_hcl_kernel=True) + z_pre, h, bias, poles, residues, t = _hcl_inputs(2048, torch.float32) + + y = _call(engine, z_pre, h, bias, 2048, poles, residues, t, long_fir_threshold=128) + assert y.shape == (B, 2048, D) diff --git a/vortex/model/engine.py b/vortex/model/engine.py index 86ff43f..afc985b 100644 --- a/vortex/model/engine.py +++ b/vortex/model/engine.py @@ -11,13 +11,12 @@ # vortex-kernels: optional fused Triton kernels (refs #16, #76), guarded so `import vortex` works without triton. try: from vortex.ops.hcs_interface import hcs_conv -except ImportError: - hcs_conv = None - -try: from vortex.ops.hcm_interface import hcm_fft_conv + from vortex.ops.hcl_interface import hcl_fft_conv except ImportError: + hcs_conv = None hcm_fft_conv = None + hcl_fft_conv = None IIR_PREFILL_MODES = [ "recurrence", @@ -127,6 +126,7 @@ def __init__( hyena_flip_x1x2=False, use_hcs_kernel=False, use_hcm_kernel=False, + use_hcl_kernel=False, ) -> None: self.fir_fn = fir_fn assert iir_prefill_style in IIR_PREFILL_MODES, f"iir_prefill_style must be one of {IIR_PREFILL_MODES}" @@ -138,6 +138,7 @@ def __init__( self.hyena_flip_x1x2 = hyena_flip_x1x2 self.use_hcs_kernel = use_hcs_kernel self.use_hcm_kernel = use_hcm_kernel + self.use_hcl_kernel = use_hcl_kernel def parallel_fir( self, @@ -337,7 +338,18 @@ def parallel_iir( x1v = x1 * v - if inference_params is not None and prefill_style == "recurrence": + # vortex-kernels: opt-in fused Triton HCL FFT-conv (refs #16, #76). + # Flag off, or any prefill/flashfft/long-fir case, takes the stock path. + _use_hcl = ( + self.use_hcl_kernel + and hcl_fft_conv is not None + and inference_params is None + and long_fir_threshold is None + and not (use_flashfft and L % 2 == 0) + ) + if _use_hcl: + y = hcl_fft_conv(h, x1v, x2, D, L, fft_size) + elif inference_params is not None and prefill_style == "recurrence": y = self.prefill_via_direct_recurrence( inference_params=inference_params, x1v=x1v, @@ -376,7 +388,8 @@ def parallel_iir( # if self.layer_idx == 2: # breakpoint() y = y.to(dtype=x1v.dtype) - y = (y + x1v * D.unsqueeze(-1)) * x2 + if not _use_hcl: + y = (y + x1v * D.unsqueeze(-1)) * x2 if self.print_activations: activations_logger.info(f"hyena filter: {h}, {h.min()}, {h.max()}") diff --git a/vortex/model/model.py b/vortex/model/model.py index 8388162..0bfe95e 100644 --- a/vortex/model/model.py +++ b/vortex/model/model.py @@ -36,6 +36,13 @@ from vortex.model.attention import MHA +# vortex-kernels: optional tiled HCL compute_filter (refs #16, #76), guarded so +# `import vortex` works without triton. +try: + from vortex.ops.hcl_interface import _hcl_compute_filter +except ImportError: + _hcl_compute_filter = None + if HAS_TE: from transformer_engine.common.recipe import Format, DelayedScaling @@ -159,6 +166,7 @@ def __init__(self, config, layer_idx, hyena_filter_groups=None, fir_inner_filter hyena_flip_x1x2=config.get("hyena_flip_x1x2", False), use_hcs_kernel=config.get("use_hcs_kernel", False), use_hcm_kernel=config.get("use_hcm_kernel", False), + use_hcl_kernel=config.get("use_hcl_kernel", False), ) self.use_flash_depthwise = config.get("use_flash_depthwise", False) self.data_dtype = None @@ -396,7 +404,15 @@ def compute_filter(self, L, device): self.residues.to(filter_dtype), self.log_poles.to(filter_dtype), ) - h = (residues[..., None] * (log_poles * self.t).exp()).sum(1)[None] # B, D, L + # vortex-kernels: opt-in tiled HCL filter build -- the in-register + # state-size sum avoids the (D, state_size, L) intermediate that OOMs + # evo2_7b at L=131k. Flag off -> the stock reduction, byte-identical. + if self.engine.use_hcl_kernel and _hcl_compute_filter is not None: + h = _hcl_compute_filter( + residues, log_poles.squeeze(-1), self.t.reshape(-1).float() + )[None] + else: + h = (residues[..., None] * (log_poles * self.t).exp()).sum(1)[None] # B, D, L return h, filter_dtype, log_poles, residues From 10828fe333f458c9c92da20a35f26a2e13cd1082 Mon Sep 17 00:00:00 2001 From: alphakhaw Date: Wed, 20 May 2026 01:04:01 +0800 Subject: [PATCH 13/16] test: add the HCL end-to-end tests --- test/test_hcl_e2e.py | 138 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 test/test_hcl_e2e.py diff --git a/test/test_hcl_e2e.py b/test/test_hcl_e2e.py new file mode 100644 index 0000000..0b3a378 --- /dev/null +++ b/test/test_hcl_e2e.py @@ -0,0 +1,138 @@ +""" +End-to-end tests for the use_hcl_kernel flag inside a real Evo2 model. + +Loads an Evo2 checkpoint once and runs the same forward with the HCL kernel +off (stock compute_filter + the parallel_iir FFT branch) and on (the tiled +filter build + the fused FFT-conv epilogue). The kernel is numerically +equivalent, not bit-exact, so the behavioural invariant is tested: the model +predicts the same tokens and the logit vectors stay near-parallel. + +test_hcl_unlocks_131k is the headline -- the stock compute_filter OOMs at +L=131k on its (D, state_size, L) intermediate; the tiled kernel removes it. +It needs an ~80GB GPU (the sign-off pod) and skips on a smaller card. + +These load a large checkpoint -- slow relative to the kernel tests; run them +deliberately, not in the fast loop. +""" + +import os + +import pytest +import torch + +from vortex.model.engine import HyenaInferenceEngine + +CUDA: bool = torch.cuda.is_available() +_MODEL_ID: str = os.environ.get("VK_E2E_MODEL", "evo2_7b") +_SEQ_LEN: int = 2048 + + +@pytest.fixture(scope="module") +def evo2_model(): + """ + Load the Evo2 model once for the module, or skip if unavailable. + """ + if not CUDA: + pytest.skip("Evo2 e2e test requires CUDA") + try: + from evo2 import Evo2 + except ImportError as exc: # pragma: no cover - depends on optional dep + pytest.skip(f"evo2 not installed: {exc}") + try: + return Evo2(_MODEL_ID) + except Exception as exc: # noqa: BLE001 - any load failure -> skip, not fail + pytest.skip(f"could not load Evo2({_MODEL_ID!r}): {exc}") + + +def _set_hcl_kernel(model, enabled: bool) -> int: + """ + Flip use_hcl_kernel on every HyenaInferenceEngine reachable from model. + + Args: + model: A loaded Evo2 model. + enabled (bool): Target value for use_hcl_kernel. + + Returns: + The number of HyenaInferenceEngine instances touched. + """ + root = getattr(model, "model", model) + touched = 0 + for module in root.modules(): + engine = getattr(module, "engine", None) + if isinstance(engine, HyenaInferenceEngine): + engine.use_hcl_kernel = enabled + touched += 1 + return touched + + +def _logits(model, input_ids: torch.Tensor) -> torch.Tensor: + """ + Run a forward pass and return the logits tensor as fp32. + """ + with torch.no_grad(): + out = model(input_ids) + while isinstance(out, (tuple, list)): + out = out[0] + return out.float() + + +@pytest.mark.skipif(not CUDA, reason="Evo2 e2e test requires CUDA") +def test_vk_hcl_e2e_matches_baseline(evo2_model) -> None: + """ + A full Evo2 forward is behaviourally unchanged when use_hcl_kernel swaps + in the tiled filter build and the fused FFT-conv epilogue. + + The kernel is numerically equivalent to the stock path but not bit-exact, + so the test asserts prediction agreement (argmax + cosine), not an + absolute logit bound -- the same rationale as the HCM e2e. + """ + torch.manual_seed(0) + input_ids = torch.randint(1, 5, (1, _SEQ_LEN), dtype=torch.int, device="cuda:0") + + try: + touched = _set_hcl_kernel(evo2_model, False) + assert touched > 0, "no HyenaInferenceEngine found in the Evo2 model" + logits_off = _logits(evo2_model, input_ids) + + _set_hcl_kernel(evo2_model, True) + logits_on = _logits(evo2_model, input_ids) + finally: + _set_hcl_kernel(evo2_model, False) + + assert logits_on.shape == logits_off.shape + + agreement = (logits_on.argmax(-1) == logits_off.argmax(-1)).float().mean().item() + cosine = torch.nn.functional.cosine_similarity( + logits_on.flatten(), logits_off.flatten(), dim=0 + ).item() + assert agreement >= 0.99, ( + f"use_hcl_kernel changed {(1 - agreement) * 100:.2f}% of token predictions" + ) + assert cosine >= 0.9999, f"use_hcl_kernel logits diverged: cosine={cosine:.6f}" + + +@pytest.mark.skipif(not CUDA, reason="Evo2 e2e test requires CUDA") +def test_hcl_unlocks_131k(evo2_model) -> None: + """ + A full evo2_7b forward at L=131072 completes with use_hcl_kernel on. + + The stock compute_filter OOMs at this length on its (D, state_size, L) + fp32 intermediate (34 GiB at D=4096); the tiled kernel removes it. Needs + an ~80GB GPU (the H100 sign-off pod) -- on a smaller card this skips: + even with the kernel, the rest of the 131k forward will not fit. + """ + torch.manual_seed(0) + input_ids = torch.randint(1, 5, (1, 131072), dtype=torch.int, device="cuda:0") + + try: + touched = _set_hcl_kernel(evo2_model, True) + assert touched > 0, "no HyenaInferenceEngine found in the Evo2 model" + logits = _logits(evo2_model, input_ids) + except torch.cuda.OutOfMemoryError: + torch.cuda.empty_cache() + pytest.skip("L=131072 needs ~80GB; run on the H100 sign-off pod") + finally: + _set_hcl_kernel(evo2_model, False) + + assert logits.shape[1] == 131072 + assert torch.isfinite(logits).all() From 350fbb0786cb49f8b7321ced2e5764cc6ebd3b72 Mon Sep 17 00:00:00 2001 From: alphakhaw Date: Wed, 20 May 2026 11:51:22 +0800 Subject: [PATCH 14/16] refactor(ops): share HC tile configs; perf and safety fixes - vortex/ops/triton_common: BDL_TILE_CONFIGS plus bdl_grid_2d/3d helpers, imported by all three HC kernels (3 duplicate config lists collapsed). - HCM hcm_fft_conv: filter rfft uses .squeeze(1); .squeeze() would also collapse a D=1 axis and break the (1, D, F) broadcast contract. - HCL hcl_fft_conv: rfft on the real signal path replaces fft+slice (~half the cuFFT signal cost); e2e cosine vs stock stays >=0.9999. - HCL _hcl_compute_filter_kernel: explicit (residues, h) strides; launcher passes them instead of relying on implicit row-major arithmetic. - HCS kernel: dropped no-op .to(tl.float32) casts (hcs_conv already passes fp32); added mask_l to the load mask to skip past-L tile reads. - model.py: log_poles.squeeze(-1) -> .reshape(D, S) for shape robustness. - engine.py: tighter opt-in dispatch comments; clarify that hcl_fft_conv already applied the post-conv (y + x1v*D) * x2 gate. --- vortex/model/engine.py | 14 +++--- vortex/model/model.py | 8 ++-- vortex/ops/hcl_interface.py | 86 ++++++++++++++++--------------------- vortex/ops/hcm_interface.py | 82 +++++++++++++---------------------- vortex/ops/hcs_interface.py | 35 ++++----------- vortex/ops/triton_common.py | 47 ++++++++++++++++++++ 6 files changed, 133 insertions(+), 139 deletions(-) create mode 100644 vortex/ops/triton_common.py diff --git a/vortex/model/engine.py b/vortex/model/engine.py index afc985b..41db945 100644 --- a/vortex/model/engine.py +++ b/vortex/model/engine.py @@ -8,7 +8,8 @@ from vortex.model.utils import column_split from vortex.logging import activations_logger -# vortex-kernels: optional fused Triton kernels (refs #16, #76), guarded so `import vortex` works without triton. +# Optional fused Triton kernels (refs #16, #76). Guarded so `import vortex` +# still works on hosts without triton. try: from vortex.ops.hcs_interface import hcs_conv from vortex.ops.hcm_interface import hcm_fft_conv @@ -169,8 +170,7 @@ def parallel_fir( if self.hyena_flip_x1x2: x1, x2 = x2, x1 - # vortex-kernels: opt-in fused Triton HCS short conv (refs #16, #76). - # Matches only the gated short-filter cascade; flag off is a no-op. + # Opt-in HCS kernel for the gated short-filter cascade (refs #16, #76). if self.use_hcs_kernel and hcs_conv is not None and fir_length < 128 and groups: z = hcs_conv(x1, x2, v, weight, bias, gated_bias=gated_bias, padding_mask=padding_mask) fir_state = (x1 * v)[..., -fir_length + 1 :] if inference_params is not None else None @@ -192,8 +192,7 @@ def parallel_fir( z = fir_fn(u)[:, :L] # B, L, D elif fir_length >= 128: - # vortex-kernels: opt-in fused Triton HCM FFT-conv (refs #16, #76). - # Flag off -> fftconv_func -> byte-identical to stock vortex. + # Opt-in HCM FFT-conv (refs #16, #76); flag off falls back to fftconv_func. fftconv = ( hcm_fft_conv if self.use_hcm_kernel and hcm_fft_conv is not None @@ -338,8 +337,8 @@ def parallel_iir( x1v = x1 * v - # vortex-kernels: opt-in fused Triton HCL FFT-conv (refs #16, #76). - # Flag off, or any prefill/flashfft/long-fir case, takes the stock path. + # Opt-in HCL FFT-conv (refs #16, #76). Skipped during prefill, when + # flashfft owns even-L, or when long_fir_threshold pins the depthwise path. _use_hcl = ( self.use_hcl_kernel and hcl_fft_conv is not None @@ -388,6 +387,7 @@ def parallel_iir( # if self.layer_idx == 2: # breakpoint() y = y.to(dtype=x1v.dtype) + # hcl_fft_conv already applied the post-conv (y + x1v*D[:, None]) * x2. if not _use_hcl: y = (y + x1v * D.unsqueeze(-1)) * x2 diff --git a/vortex/model/model.py b/vortex/model/model.py index 0bfe95e..a6c44ce 100644 --- a/vortex/model/model.py +++ b/vortex/model/model.py @@ -404,12 +404,12 @@ def compute_filter(self, L, device): self.residues.to(filter_dtype), self.log_poles.to(filter_dtype), ) - # vortex-kernels: opt-in tiled HCL filter build -- the in-register - # state-size sum avoids the (D, state_size, L) intermediate that OOMs - # evo2_7b at L=131k. Flag off -> the stock reduction, byte-identical. + # vortex-kernels: opt-in tiled HCL filter build avoids the + # (D, state_size, L) intermediate that OOMs evo2_7b at L=131k. if self.engine.use_hcl_kernel and _hcl_compute_filter is not None: + D, S = residues.shape[0], residues.shape[1] h = _hcl_compute_filter( - residues, log_poles.squeeze(-1), self.t.reshape(-1).float() + residues, log_poles.reshape(D, S), self.t.reshape(-1).float() )[None] else: h = (residues[..., None] * (log_poles * self.t).exp()).sum(1)[None] # B, D, L diff --git a/vortex/ops/hcl_interface.py b/vortex/ops/hcl_interface.py index 8c29d20..6127d27 100644 --- a/vortex/ops/hcl_interface.py +++ b/vortex/ops/hcl_interface.py @@ -1,4 +1,3 @@ -# pyright: reportAttributeAccessIssue=none """ HCL -- Hyena Cascade Long. @@ -13,28 +12,15 @@ compute_filter swaps in so the (D, state_size, L) intermediate never exists. """ -from typing import Callable - import torch import triton import triton.language as tl from vortex.ops.hcm_interface import _hcm_complex_mul - -# Autotuned 2-D BLOCK_D x BLOCK_L tile, shared by both kernels -- both are -# memory-bound elementwise work over a (D, L) grid, so the winner is whichever -# tile best saturates bandwidth, benchmarked once per (D, L) and cached. -_TILE_CONFIGS: list[triton.Config] = [ - triton.Config({"BLOCK_D": 32, "BLOCK_L": 64}, num_warps=2), - triton.Config({"BLOCK_D": 64, "BLOCK_L": 64}, num_warps=4), - triton.Config({"BLOCK_D": 64, "BLOCK_L": 128}, num_warps=4), - triton.Config({"BLOCK_D": 128, "BLOCK_L": 64}, num_warps=4), - triton.Config({"BLOCK_D": 64, "BLOCK_L": 256}, num_warps=8), - triton.Config({"BLOCK_D": 128, "BLOCK_L": 128}, num_warps=8), -] +from vortex.ops.triton_common import BDL_TILE_CONFIGS, bdl_grid_2d, bdl_grid_3d -@triton.autotune(configs=_TILE_CONFIGS, key=["D", "L"]) +@triton.autotune(configs=BDL_TILE_CONFIGS, key=["D", "L"]) @triton.jit def _hcl_compute_filter_kernel( residues_ptr, @@ -43,6 +29,10 @@ def _hcl_compute_filter_kernel( h_ptr, D, L, + stride_rs_d, + stride_rs_s, + stride_h_d, + stride_h_l, S: tl.constexpr, BLOCK_D: tl.constexpr, BLOCK_L: tl.constexpr, @@ -50,10 +40,9 @@ def _hcl_compute_filter_kernel( """ Modal filter: h[d, l] = sum_s residues[d, s] * exp(log_poles[d, s] * t[l]). - One program covers a (BLOCK_D, BLOCK_L) tile of h. The state-size sum (S - terms) runs in the fp32 register accumulator, so the (D, S, L) intermediate - that OOMs the stock compute_filter at L=131k never exists. residues and - log_poles are (D, S) row-major; t is (L,); h is (D, L) row-major. + One program covers a (BLOCK_D, BLOCK_L) tile of h; the state-size sum (S + terms) runs in a fp32 register accumulator so the (D, S, L) intermediate + that OOMs the stock compute_filter at L=131k never exists. """ pid_d = tl.program_id(0) pid_l = tl.program_id(1) @@ -67,15 +56,12 @@ def _hcl_compute_filter_kernel( acc = tl.zeros((BLOCK_D, BLOCK_L), dtype=tl.float32) for s in tl.static_range(S): - r_s = tl.load(residues_ptr + offs_d * S + s, mask=mask_d, other=0.0).to( - tl.float32 - ) - lp_s = tl.load(log_poles_ptr + offs_d * S + s, mask=mask_d, other=0.0).to( - tl.float32 - ) + rs_offs = offs_d * stride_rs_d + s * stride_rs_s + r_s = tl.load(residues_ptr + rs_offs, mask=mask_d, other=0.0).to(tl.float32) + lp_s = tl.load(log_poles_ptr + rs_offs, mask=mask_d, other=0.0).to(tl.float32) acc += r_s[:, None] * tl.exp(lp_s[:, None] * t_tile[None, :]) - h_ptrs = h_ptr + offs_d[:, None] * L + offs_l[None, :] + h_ptrs = h_ptr + offs_d[:, None] * stride_h_d + offs_l[None, :] * stride_h_l tl.store(h_ptrs, acc, mask=mask_d[:, None] & mask_l[None, :]) @@ -123,17 +109,23 @@ def _hcl_compute_filter( t = t.contiguous().float() h: torch.Tensor = torch.empty(D, L, dtype=torch.float32, device=residues.device) - # BLOCK_D / BLOCK_L are supplied by @triton.autotune; the grid is a - # callable so it can read the chosen tile sizes from the winning config. - grid: Callable[[triton.Config], tuple[int, int]] = lambda meta: ( - triton.cdiv(D, meta["BLOCK_D"]), - triton.cdiv(L, meta["BLOCK_L"]), + _hcl_compute_filter_kernel[bdl_grid_2d(D, L)]( + residues, + log_poles, + t, + h, + D, + L, + residues.stride(0), + residues.stride(1), + h.stride(0), + h.stride(1), + S, ) - _hcl_compute_filter_kernel[grid](residues, log_poles, t, h, D, L, S) return h -@triton.autotune(configs=_TILE_CONFIGS, key=["D", "L"]) +@triton.autotune(configs=BDL_TILE_CONFIGS, key=["D", "L"]) @triton.jit def _hcl_bias_residual_gate_kernel( y_ptr, @@ -224,15 +216,7 @@ def _hcl_bias_residual_gate( x2 = x2.contiguous() bias = bias.contiguous() out: torch.Tensor = torch.empty_like(x1v) - - # BLOCK_D / BLOCK_L are supplied by @triton.autotune; the grid is a - # callable so it can read the chosen tile sizes from the winning config. - grid: Callable[[triton.Config], tuple[int, int, int]] = lambda meta: ( - B, - triton.cdiv(D, meta["BLOCK_D"]), - triton.cdiv(L, meta["BLOCK_L"]), - ) - _hcl_bias_residual_gate_kernel[grid]( + _hcl_bias_residual_gate_kernel[bdl_grid_3d(B, D, L)]( y, x1v, bias, @@ -258,11 +242,14 @@ def hcl_fft_conv( """ Fused HCL FFT-convolution epilogue. - Reproduces parallel_iir's long_fir_threshold-is-None branch and its - post-conv in one path: cuFFT keeps the three transforms, _hcm_complex_mul - does the scaled spectral product X*H (stage 3, with stage 1's /fft_size - folded in), and _hcl_bias_residual_gate does the post-conv - (y + x1v*D[:, None]) * x2. Returns y fully gated and residualed. + Reproduces parallel_iir's long_fir_threshold-is-None branch in one path: + cuFFT keeps the three transforms, _hcm_complex_mul does the spectral + product X*H scaled by 1/fft_size, and _hcl_bias_residual_gate does the + post-conv (y + x1v*D[:, None]) * x2. + + The signal FFT uses rfft -- mathematically identical to fft on a real + input but produces only fft_size//2 + 1 bins, so cuFFT does ~half the + work the stock parallel_iir's `fft` + slice path does. Args: h (torch.Tensor): The modal filter, shape (1, D, L). @@ -276,8 +263,7 @@ def hcl_fft_conv( torch.Tensor: (y + x1v*D[:, None]) * x2, shape (1, D, L), x1v's dtype. """ H = torch.fft.rfft(h.to(dtype=torch.float32), n=fft_size) - X_s = torch.fft.fft(x1v.to(dtype=torch.float32), n=fft_size) - X = X_s[..., : H.shape[-1]] + X = torch.fft.rfft(x1v.to(dtype=torch.float32), n=fft_size) prod = _hcm_complex_mul(X, H, fft_size) y = torch.fft.irfft(prod, n=fft_size, norm="forward")[..., :L] diff --git a/vortex/ops/hcm_interface.py b/vortex/ops/hcm_interface.py index bb5728c..69a1afc 100644 --- a/vortex/ops/hcm_interface.py +++ b/vortex/ops/hcm_interface.py @@ -1,4 +1,3 @@ -# pyright: reportAttributeAccessIssue=none """ HCM -- Hyena Cascade Medium. @@ -13,15 +12,16 @@ u_f * k_f) and _hcm_bias_residual (stage 5, the skip-residual add y + u*bias). """ -from typing import Callable +from collections.abc import Callable import torch import triton import triton.language as tl -# Autotuned search spaces -- both kernels are memory-bound elementwise work, -# so the winning tile is whichever best saturates bandwidth at a given shape. -# Triton benchmarks each set once per shape key and caches the winner. +from vortex.ops.triton_common import BDL_TILE_CONFIGS, bdl_grid_3d + +# 1-D BLOCK sweep for the complex multiply -- the only HC kernel that flattens +# (D, F) to a single axis, so it can't share BDL_TILE_CONFIGS. _COMPLEX_MUL_CONFIGS: list[triton.Config] = [ triton.Config({"BLOCK": 256}, num_warps=2), triton.Config({"BLOCK": 512}, num_warps=4), @@ -29,15 +29,6 @@ triton.Config({"BLOCK": 2048}, num_warps=8), ] -_BIAS_RESIDUAL_CONFIGS: list[triton.Config] = [ - triton.Config({"BLOCK_D": 32, "BLOCK_L": 64}, num_warps=2), - triton.Config({"BLOCK_D": 64, "BLOCK_L": 64}, num_warps=4), - triton.Config({"BLOCK_D": 64, "BLOCK_L": 128}, num_warps=4), - triton.Config({"BLOCK_D": 128, "BLOCK_L": 64}, num_warps=4), - triton.Config({"BLOCK_D": 64, "BLOCK_L": 256}, num_warps=8), - triton.Config({"BLOCK_D": 128, "BLOCK_L": 128}, num_warps=8), -] - @triton.autotune(configs=_COMPLEX_MUL_CONFIGS, key=["DF"]) @triton.jit @@ -132,9 +123,7 @@ def _hcm_complex_mul( y_r = torch.view_as_real(y_f) DF = D * F - # BLOCK is supplied by @triton.autotune; the grid is a callable so it can - # read the chosen tile size from the winning config. - grid: Callable[[triton.Config], tuple[int, int]] = lambda meta: ( + grid: Callable[[dict], tuple[int, int]] = lambda meta: ( triton.cdiv(DF, meta["BLOCK"]), B, ) @@ -149,7 +138,7 @@ def _hcm_complex_mul( return y_f -@triton.autotune(configs=_BIAS_RESIDUAL_CONFIGS, key=["D", "L"]) +@triton.autotune(configs=BDL_TILE_CONFIGS, key=["D", "L"]) @triton.jit def _hcm_bias_residual_kernel( y_ptr, @@ -230,15 +219,7 @@ def _hcm_bias_residual( u = u.contiguous() bias = bias.contiguous() out: torch.Tensor = torch.empty_like(u) - - # BLOCK_D / BLOCK_L are supplied by @triton.autotune; the grid is a - # callable so it can read the chosen tile sizes from the winning config. - grid: Callable[[triton.Config], tuple[int, int, int]] = lambda meta: ( - B, - triton.cdiv(D, meta["BLOCK_D"]), - triton.cdiv(L, meta["BLOCK_L"]), - ) - _hcm_bias_residual_kernel[grid]( + _hcm_bias_residual_kernel[bdl_grid_3d(B, D, L)]( y, u, bias, @@ -268,43 +249,40 @@ def hcm_fft_conv( Fused HCM FFT-convolution -- a drop-in for fftconv_func. Reproduces fftconv_func's non-bidirectional inference path with the - elementwise glue fused into Triton kernels: cuFFT keeps the three - transforms, _hcm_complex_mul does the scaled spectral product (stages - 1 + 3) and _hcm_bias_residual the skip-residual add (stage 5). The - signature matches fftconv_func so the engine dispatch is a one-line swap. + elementwise glue fused into Triton kernels. cuFFT keeps the three + transforms; _hcm_complex_mul does the scaled spectral product and + _hcm_bias_residual does the skip-residual add. The signature mirrors + fftconv_func so the engine dispatch is a one-line swap; the trailing + args (dropout_mask, gelu, k_rev, bidirectional, print_activations, + layer_idx, **kwargs) exist only for that parity. Args: u (torch.Tensor): Input activations, shape (B, D, L). - k (torch.Tensor): Filter, shape (D, 1, K) (squeezable to (D, K)). + k (torch.Tensor): Filter, shape (D, 1, K). D (torch.Tensor): Per-channel skip-connection bias, shape (D,). - dropout_mask (torch.Tensor | None): Unused on the inference path; - kept for fftconv_func parity. - gelu (bool): Unused -- fftconv_func never applies it; kept for parity. - k_rev (torch.Tensor | None): Reverse filter; must be None -- the HCM - dispatch never sets it. - bidirectional (bool): Must be False -- the HCM branch is causal. - print_activations (bool): Accepted for parity; this path does not log. - layer_idx (int | None): Accepted for parity; unused. Returns: - torch.Tensor: y + u * D[:, None], shape (B, D, L), u's dtype -- - identical in shape, dtype and value to fftconv_func. + torch.Tensor: y + u * D[:, None], shape (B, D, L), u's dtype. Raises: - NotImplementedError: If bidirectional is True or k_rev is set; the - HCM dispatch never exercises those paths. + NotImplementedError: If bidirectional is True or k_rev is set; + unsupported paths the HCM dispatch never hits. """ if bidirectional or k_rev is not None: - raise NotImplementedError("hcm_fft_conv handles only the causal, non-reverse path") + raise NotImplementedError( + "hcm_fft_conv handles only the causal, non-reverse path" + ) seqlen = u.shape[-1] fft_size = 2 * seqlen - # rfft(k) reshaped to (1, D, F) for the batch broadcast -- inlined, not - # adjust_filter_shape_for_broadcast, to avoid an engine import cycle. - k_f = torch.fft.rfft(k, n=fft_size).squeeze().unsqueeze(0) - u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) # stage 2 + # rfft(k) reshaped to (1, D, F) for the batch broadcast -- inlined to avoid + # the adjust_filter_shape_for_broadcast import cycle. squeeze(1) drops only + # the channel-group axis; .squeeze() with no arg would also collapse a D=1 + # case and break the (1, D, F) broadcast contract. + k_f = torch.fft.rfft(k, n=fft_size).squeeze(1).unsqueeze(0) + u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) - prod = _hcm_complex_mul(u_f, k_f, fft_size) # stages 1 + 3 - y = torch.fft.irfft(prod, n=fft_size, norm="forward")[..., :seqlen] # stage 4 - return _hcm_bias_residual(y, u, D) # stages 5 + 6 + prod = _hcm_complex_mul(u_f, k_f, fft_size) + y = torch.fft.irfft(prod, n=fft_size, norm="forward")[..., :seqlen] + return _hcm_bias_residual(y, u, D) diff --git a/vortex/ops/hcs_interface.py b/vortex/ops/hcs_interface.py index 78b75b6..e6c2f85 100644 --- a/vortex/ops/hcs_interface.py +++ b/vortex/ops/hcs_interface.py @@ -1,4 +1,3 @@ -# pyright: reportAttributeAccessIssue=none """ HCS -- Hyena Cascade Short. @@ -11,26 +10,14 @@ adapter that wires it behind the use_hcs_kernel config flag is added alongside. """ -from typing import Callable - import torch import triton import triton.language as tl -# Autotuned search space for the conv kernel's register-tile sizes. Triton -# benchmarks these once per (D, L, FIR_LEN) and caches the winner, so no -# single tile size is hard-coded -- the GPU and shape pick it. -_AUTOTUNE_CONFIGS: list[triton.Config] = [ - triton.Config({"BLOCK_D": 32, "BLOCK_L": 64}, num_warps=2), - triton.Config({"BLOCK_D": 64, "BLOCK_L": 64}, num_warps=4), - triton.Config({"BLOCK_D": 64, "BLOCK_L": 128}, num_warps=4), - triton.Config({"BLOCK_D": 128, "BLOCK_L": 64}, num_warps=4), - triton.Config({"BLOCK_D": 64, "BLOCK_L": 256}, num_warps=8), - triton.Config({"BLOCK_D": 128, "BLOCK_L": 128}, num_warps=8), -] +from vortex.ops.triton_common import BDL_TILE_CONFIGS, bdl_grid_3d -@triton.autotune(configs=_AUTOTUNE_CONFIGS, key=["D", "L", "FIR_LEN"]) +@triton.autotune(configs=BDL_TILE_CONFIGS, key=["D", "L", "FIR_LEN"]) @triton.jit def _hcs_depthwise_conv_kernel( u_ptr, @@ -62,18 +49,21 @@ def _hcs_depthwise_conv_kernel( offs_l = pid_l * BLOCK_L + tl.arange(0, BLOCK_L) mask_d = offs_d < D mask_l = offs_l < L + tile_mask = mask_d[:, None] & mask_l[None, :] u_base = u_ptr + pid_b * stride_ub + offs_d[:, None] * stride_ud acc = tl.zeros((BLOCK_D, BLOCK_L), dtype=tl.float32) + # hcs_conv forces fp32 before launch, so loaded tiles are already fp32 -- + # the .to(tl.float32) calls are no-ops at runtime. for k in tl.static_range(FIR_LEN): w_k = tl.load( w_ptr + offs_d * stride_wd + k * stride_wk, mask=mask_d, other=0.0 ) pos = offs_l - (FIR_LEN - 1) + k - mask_pos = mask_d[:, None] & (pos[None, :] >= 0) & (pos[None, :] < L) + mask_pos = tile_mask & (pos[None, :] >= 0) & (pos[None, :] < L) u_tile = tl.load(u_base + pos[None, :] * stride_ul, mask=mask_pos, other=0.0) - acc += w_k[:, None].to(tl.float32) * u_tile.to(tl.float32) + acc += w_k[:, None] * u_tile z_ptrs = ( z_ptr @@ -81,7 +71,7 @@ def _hcs_depthwise_conv_kernel( + offs_d[:, None] * stride_ud + offs_l[None, :] * stride_ul ) - tl.store(z_ptrs, acc, mask=mask_d[:, None] & mask_l[None, :]) + tl.store(z_ptrs, acc, mask=tile_mask) def hcs_depthwise_conv(u: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: @@ -112,14 +102,7 @@ def hcs_depthwise_conv(u: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: raise ValueError(f"weight {tuple(weight.shape)} is not depthwise for D={D}") z: torch.Tensor = torch.empty_like(u) - # BLOCK_D / BLOCK_L are supplied by @triton.autotune; the grid is a - # callable so it can read the chosen tile sizes from the winning config. - grid: Callable[[triton.Config], tuple[int, int, int]] = lambda meta: ( - B, - triton.cdiv(D, meta["BLOCK_D"]), - triton.cdiv(L, meta["BLOCK_L"]), - ) - _hcs_depthwise_conv_kernel[grid]( + _hcs_depthwise_conv_kernel[bdl_grid_3d(B, D, L)]( u, weight, z, diff --git a/vortex/ops/triton_common.py b/vortex/ops/triton_common.py new file mode 100644 index 0000000..a21d343 --- /dev/null +++ b/vortex/ops/triton_common.py @@ -0,0 +1,47 @@ +""" +Shared autotune configs and grid helpers for the HC{S,M,L} Triton kernels. + +Pieces here are kept intentionally small. Only put a config or helper here +when it is genuinely identical across kernels -- any kernel that needs a +different tile or grid declares its own. Each @triton.autotune-decorated +kernel keeps its own benchmark cache, so sharing the config list is a +syntactic convenience, not a performance contract. +""" + +from collections.abc import Callable + +import triton + +# 2-D (BLOCK_D, BLOCK_L) tile sweep for memory-bound elementwise kernels over +# a (D, L) plane. Re-benchmarked per shape key by each decorated kernel. +BDL_TILE_CONFIGS: list[triton.Config] = [ + triton.Config({"BLOCK_D": 32, "BLOCK_L": 64}, num_warps=2), + triton.Config({"BLOCK_D": 64, "BLOCK_L": 64}, num_warps=4), + triton.Config({"BLOCK_D": 64, "BLOCK_L": 128}, num_warps=4), + triton.Config({"BLOCK_D": 128, "BLOCK_L": 64}, num_warps=4), + triton.Config({"BLOCK_D": 64, "BLOCK_L": 256}, num_warps=8), + triton.Config({"BLOCK_D": 128, "BLOCK_L": 128}, num_warps=8), +] + + +def bdl_grid_3d(B: int, D: int, L: int) -> Callable[[dict], tuple[int, int, int]]: + """ + Standard (B, cdiv(D, BLOCK_D), cdiv(L, BLOCK_L)) grid for the HC bias-residual + and HCS conv kernels. Reads tile sizes from the autotune meta-dict. + """ + return lambda meta: ( + B, + triton.cdiv(D, meta["BLOCK_D"]), + triton.cdiv(L, meta["BLOCK_L"]), + ) + + +def bdl_grid_2d(D: int, L: int) -> Callable[[dict], tuple[int, int]]: + """ + 2-D (cdiv(D, BLOCK_D), cdiv(L, BLOCK_L)) grid for kernels without a batch + axis (the HCL filter build). + """ + return lambda meta: ( + triton.cdiv(D, meta["BLOCK_D"]), + triton.cdiv(L, meta["BLOCK_L"]), + ) From a1dcc0f0795aa343f1d1cf5dd46977a7785d5042 Mon Sep 17 00:00:00 2001 From: alphakhaw Date: Wed, 20 May 2026 11:51:40 +0800 Subject: [PATCH 15/16] test: gpu/e2e/slow pytest markers and conftest auto-skip pyproject markers + conftest hook replace 16 @pytest.mark.skipif(not CUDA) decorators across the 12 HC test files. e2e tests additionally tagged @pytest.mark.e2e @pytest.mark.slow. pytest -m 'not e2e' 67 fast-loop tests pytest -m e2e 4 Evo2-loaded tests pytest -m 'gpu and not e2e' fast kernel + adapter + engine wiring --- pyproject.toml | 23 ++++++++++++++++------- test/conftest.py | 17 +++++++++++++++++ test/test_hcl_adapter.py | 4 +--- test/test_hcl_e2e.py | 11 +++++++---- test/test_hcl_engine.py | 8 +++----- test/test_hcl_kernel.py | 12 +++++------- test/test_hcm_adapter.py | 6 ++---- test/test_hcm_e2e.py | 7 ++++--- test/test_hcm_engine.py | 18 ++++++++++-------- test/test_hcm_kernel.py | 10 ++++------ test/test_hcs_adapter.py | 10 ++++------ test/test_hcs_e2e.py | 7 ++++--- test/test_hcs_engine.py | 10 ++++------ test/test_hcs_kernel.py | 8 +++----- 14 files changed, 84 insertions(+), 67 deletions(-) create mode 100644 test/conftest.py diff --git a/pyproject.toml b/pyproject.toml index dce54c9..232cae8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,19 +7,28 @@ authors = [{ name = "Michael Poli" }] requires-python = ">=3.10" license-files = ["LICENSE"] dependencies = [ - "torch", - "numpy", - "einops==0.8.1", - "packaging", - "rich", - "tqdm", - "PyYAML", + "torch", + "numpy", + "einops==0.8.1", + "packaging", + "rich", + "tqdm", + "PyYAML", ] [tool.setuptools.packages.find] where = ["."] include = ["vortex*"] +[tool.pytest.ini_options] +testpaths = ["test"] +markers = [ + "gpu: requires a CUDA device", + "e2e: end-to-end test that loads an Evo2 checkpoint (slow, large GPU memory)", + "slow: takes >10s", +] +filterwarnings = ["ignore::DeprecationWarning:torch\\..*"] + [tool.black] line-length = 119 diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 0000000..a35e0ce --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,17 @@ +import pytest +import torch + + +def pytest_collection_modifyitems( + config: pytest.Config, items: list[pytest.Item] +) -> None: + """ + Auto-skip tests marked @pytest.mark.gpu when no CUDA device is present. + """ + _ = config + if torch.cuda.is_available(): + return + skip_gpu = pytest.mark.skip(reason="requires CUDA device") + for item in items: + if "gpu" in item.keywords: + item.add_marker(skip_gpu) diff --git a/test/test_hcl_adapter.py b/test/test_hcl_adapter.py index 91b1c21..0c4176d 100644 --- a/test/test_hcl_adapter.py +++ b/test/test_hcl_adapter.py @@ -12,8 +12,6 @@ from vortex.ops.hcl_interface import hcl_fft_conv -CUDA: bool = torch.cuda.is_available() - def _hcl_branch_ref( h: torch.Tensor, @@ -33,7 +31,7 @@ def _hcl_branch_ref( return (y + x1v * D.unsqueeze(-1)) * x2 -@pytest.mark.skipif(not CUDA, reason="hcl_fft_conv requires CUDA") +@pytest.mark.gpu @pytest.mark.parametrize("L", [2048, 8192, 32768]) def test_hcl_fft_conv_matches_branch(L: int) -> None: """ diff --git a/test/test_hcl_e2e.py b/test/test_hcl_e2e.py index 0b3a378..323af89 100644 --- a/test/test_hcl_e2e.py +++ b/test/test_hcl_e2e.py @@ -22,7 +22,6 @@ from vortex.model.engine import HyenaInferenceEngine -CUDA: bool = torch.cuda.is_available() _MODEL_ID: str = os.environ.get("VK_E2E_MODEL", "evo2_7b") _SEQ_LEN: int = 2048 @@ -32,7 +31,7 @@ def evo2_model(): """ Load the Evo2 model once for the module, or skip if unavailable. """ - if not CUDA: + if not torch.cuda.is_available(): pytest.skip("Evo2 e2e test requires CUDA") try: from evo2 import Evo2 @@ -76,7 +75,9 @@ def _logits(model, input_ids: torch.Tensor) -> torch.Tensor: return out.float() -@pytest.mark.skipif(not CUDA, reason="Evo2 e2e test requires CUDA") +@pytest.mark.gpu +@pytest.mark.e2e +@pytest.mark.slow def test_vk_hcl_e2e_matches_baseline(evo2_model) -> None: """ A full Evo2 forward is behaviourally unchanged when use_hcl_kernel swaps @@ -111,7 +112,9 @@ def test_vk_hcl_e2e_matches_baseline(evo2_model) -> None: assert cosine >= 0.9999, f"use_hcl_kernel logits diverged: cosine={cosine:.6f}" -@pytest.mark.skipif(not CUDA, reason="Evo2 e2e test requires CUDA") +@pytest.mark.gpu +@pytest.mark.e2e +@pytest.mark.slow def test_hcl_unlocks_131k(evo2_model) -> None: """ A full evo2_7b forward at L=131072 completes with use_hcl_kernel on. diff --git a/test/test_hcl_engine.py b/test/test_hcl_engine.py index 595a5f0..f420dd7 100644 --- a/test/test_hcl_engine.py +++ b/test/test_hcl_engine.py @@ -11,8 +11,6 @@ from vortex.model.engine import HyenaInferenceEngine -CUDA: bool = torch.cuda.is_available() - # evo2_7b HCL shapes: D=4096, state_size=16. B, D, S = 1, 4096, 16 DIMS: tuple[int, int, int, int, int] = (D, 32, D // 32, 16, 256) @@ -50,7 +48,7 @@ def _call(engine, z_pre, h, bias, L, poles, residues, t, **kw): ) -@pytest.mark.skipif(not CUDA, reason="HCL kernel requires CUDA") +@pytest.mark.gpu @pytest.mark.parametrize("L", [2048, 8192]) def test_vk_hcl_on_matches_baseline_fp32(L: int) -> None: """ @@ -68,7 +66,7 @@ def test_vk_hcl_on_matches_baseline_fp32(L: int) -> None: assert (y_on - y_off).abs().max().item() < 1e-2 -@pytest.mark.skipif(not CUDA, reason="HCL kernel requires CUDA") +@pytest.mark.gpu def test_vk_hcl_off_by_default() -> None: """ A fresh HyenaInferenceEngine has use_hcl_kernel False -- the stock path. @@ -86,7 +84,7 @@ def test_vk_hcl_off_by_default() -> None: assert (y_default - y_explicit).abs().max().item() == 0.0 -@pytest.mark.skipif(not CUDA, reason="HCL kernel requires CUDA") +@pytest.mark.gpu def test_vk_hcl_predicate_skips_long_fir() -> None: """ The branch matches only long_fir_threshold is None -- a set threshold diff --git a/test/test_hcl_kernel.py b/test/test_hcl_kernel.py index 791f409..709e363 100644 --- a/test/test_hcl_kernel.py +++ b/test/test_hcl_kernel.py @@ -14,8 +14,6 @@ from vortex.ops.hcl_interface import _hcl_bias_residual_gate, _hcl_compute_filter -CUDA: bool = torch.cuda.is_available() - def _modal_filter_ref( residues: torch.Tensor, log_poles: torch.Tensor, t: torch.Tensor @@ -42,7 +40,7 @@ def _hcl_inputs( return residues, log_poles, t -@pytest.mark.skipif(not CUDA, reason="HCL Triton kernel requires CUDA") +@pytest.mark.gpu @pytest.mark.parametrize("L", [2048, 8192, 32768]) def test_hcl_compute_filter_matches_oracle(L: int) -> None: """ @@ -62,7 +60,7 @@ def test_hcl_compute_filter_matches_oracle(L: int) -> None: assert mean_diff < 1e-4, f"mean_diff={mean_diff:.2e}" -@pytest.mark.skipif(not CUDA, reason="HCL Triton kernel requires CUDA") +@pytest.mark.gpu def test_hcl_compute_filter_masks_ragged_tile() -> None: """ The kernel masks (D, L) tiles that BLOCK_D x BLOCK_L does not divide evenly. @@ -74,7 +72,7 @@ def test_hcl_compute_filter_masks_ragged_tile() -> None: assert (h - h_ref).abs().max().item() < 1e-3 -@pytest.mark.skipif(not CUDA, reason="HCL Triton kernel requires CUDA") +@pytest.mark.gpu def test_hcl_compute_filter_avoids_the_intermediate() -> None: """ The kernel's peak allocation stays well below the reference path, which @@ -101,7 +99,7 @@ def test_hcl_compute_filter_avoids_the_intermediate() -> None: ) -@pytest.mark.skipif(not CUDA, reason="HCL Triton kernel requires CUDA") +@pytest.mark.gpu @pytest.mark.parametrize("B", [1, 2]) @pytest.mark.parametrize("L", [2048, 8192, 32768]) def test_hcl_bias_residual_gate_matches_oracle(B: int, L: int) -> None: @@ -127,7 +125,7 @@ def test_hcl_bias_residual_gate_matches_oracle(B: int, L: int) -> None: assert mean_diff < 1e-5, f"mean_diff={mean_diff:.2e}" -@pytest.mark.skipif(not CUDA, reason="HCL Triton kernel requires CUDA") +@pytest.mark.gpu def test_hcl_bias_residual_gate_masks_ragged_tile() -> None: """ The kernel masks (D, L) tiles that BLOCK_D x BLOCK_L does not divide evenly. diff --git a/test/test_hcm_adapter.py b/test/test_hcm_adapter.py index 3ca671e..4eabea7 100644 --- a/test/test_hcm_adapter.py +++ b/test/test_hcm_adapter.py @@ -12,8 +12,6 @@ from vortex.model.engine import fftconv_func from vortex.ops.hcm_interface import hcm_fft_conv -CUDA: bool = torch.cuda.is_available() - def _hcm_inputs(B: int, L: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ @@ -27,7 +25,7 @@ def _hcm_inputs(B: int, L: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tenso return u, weight, bias -@pytest.mark.skipif(not CUDA, reason="hcm_fft_conv requires CUDA") +@pytest.mark.gpu @pytest.mark.parametrize("B", [1, 2]) @pytest.mark.parametrize("L", [2048, 8192, 32768]) def test_hcm_fft_conv_matches_fftconv_func(B: int, L: int) -> None: @@ -48,7 +46,7 @@ def test_hcm_fft_conv_matches_fftconv_func(B: int, L: int) -> None: assert mean_diff < 1e-3, f"mean_diff={mean_diff:.2e}" -@pytest.mark.skipif(not CUDA, reason="hcm_fft_conv requires CUDA") +@pytest.mark.gpu def test_hcm_fft_conv_rejects_unsupported_paths() -> None: """ hcm_fft_conv raises on bidirectional or reverse-filter calls -- paths the diff --git a/test/test_hcm_e2e.py b/test/test_hcm_e2e.py index bf30166..c7e16ae 100644 --- a/test/test_hcm_e2e.py +++ b/test/test_hcm_e2e.py @@ -19,7 +19,6 @@ from vortex.model.engine import HyenaInferenceEngine -CUDA: bool = torch.cuda.is_available() _MODEL_ID: str = os.environ.get("VK_E2E_MODEL", "evo2_7b") _SEQ_LEN: int = 2048 @@ -29,7 +28,7 @@ def evo2_model(): """ Load the Evo2 model once for the module, or skip if unavailable. """ - if not CUDA: + if not torch.cuda.is_available(): pytest.skip("Evo2 e2e test requires CUDA") try: from evo2 import Evo2 @@ -79,7 +78,9 @@ def _logits(model, input_ids: torch.Tensor) -> torch.Tensor: return out.float() -@pytest.mark.skipif(not CUDA, reason="Evo2 e2e test requires CUDA") +@pytest.mark.gpu +@pytest.mark.e2e +@pytest.mark.slow def test_vk_hcm_e2e_matches_baseline(evo2_model) -> None: """ A full Evo2 forward is behaviourally unchanged when use_hcm_kernel swaps diff --git a/test/test_hcm_engine.py b/test/test_hcm_engine.py index 7b83b8a..7f3fa5e 100644 --- a/test/test_hcm_engine.py +++ b/test/test_hcm_engine.py @@ -13,8 +13,6 @@ from vortex.model.engine import HyenaInferenceEngine -CUDA: bool = torch.cuda.is_available() - # evo2_7b HCM cascade shapes: D=4096, fir_length=128. B, D, K, GROUPS = 1, 4096, 128, 256 DIMS: tuple[int, int, int, int, int] = (D, 32, D // 32, 16, GROUPS) @@ -38,7 +36,7 @@ def _hcm_inputs(L: int, dtype: torch.dtype): return u, weight, bias -@pytest.mark.skipif(not CUDA, reason="HCM kernel requires CUDA") +@pytest.mark.gpu @pytest.mark.parametrize("L", [2048, 8192]) def test_vk_hcm_on_matches_baseline_fp32(L: int) -> None: """ @@ -57,7 +55,7 @@ def test_vk_hcm_on_matches_baseline_fp32(L: int) -> None: assert (z_on - z_off).abs().max().item() < 1e-2 -@pytest.mark.skipif(not CUDA, reason="HCM kernel requires CUDA") +@pytest.mark.gpu def test_vk_hcm_on_matches_baseline_bf16() -> None: """ use_hcm_kernel on reproduces the stock HCM output in bf16, the inference dtype. @@ -74,7 +72,7 @@ def test_vk_hcm_on_matches_baseline_bf16() -> None: torch.testing.assert_close(z_on, z_off, rtol=2e-2, atol=2e-2) -@pytest.mark.skipif(not CUDA, reason="HCM kernel requires CUDA") +@pytest.mark.gpu def test_vk_hcm_off_by_default() -> None: """ A fresh HyenaInferenceEngine has use_hcm_kernel False, so parallel_fir @@ -85,16 +83,20 @@ def test_vk_hcm_off_by_default() -> None: assert engine.use_hcm_kernel is False u, weight, bias = _hcm_inputs(2048, torch.float32) - z_default, _ = engine.parallel_fir(F.conv1d, u, weight, bias, 2048, DIMS, **_CASCADE_KW) + z_default, _ = engine.parallel_fir( + F.conv1d, u, weight, bias, 2048, DIMS, **_CASCADE_KW + ) explicit_off = HyenaInferenceEngine(layer_idx=0, use_hcm_kernel=False) - z_explicit, _ = explicit_off.parallel_fir(F.conv1d, u, weight, bias, 2048, DIMS, **_CASCADE_KW) + z_explicit, _ = explicit_off.parallel_fir( + F.conv1d, u, weight, bias, 2048, DIMS, **_CASCADE_KW + ) # both took the stock path, so they are bitwise identical assert (z_default - z_explicit).abs().max().item() == 0.0 -@pytest.mark.skipif(not CUDA, reason="HCM kernel requires CUDA") +@pytest.mark.gpu def test_vk_hcm_predicate_skips_hcs_calls() -> None: """ The branch matches only the fir_length >= 128 cascade -- an HCS-length diff --git a/test/test_hcm_kernel.py b/test/test_hcm_kernel.py index 35cb7f0..8a41349 100644 --- a/test/test_hcm_kernel.py +++ b/test/test_hcm_kernel.py @@ -14,8 +14,6 @@ from vortex.model.engine import adjust_filter_shape_for_broadcast from vortex.ops.hcm_interface import _hcm_bias_residual, _hcm_complex_mul -CUDA: bool = torch.cuda.is_available() - def _spectra( B: int, D: int, L: int, fir_length: int @@ -37,7 +35,7 @@ def _spectra( return u_f, k_f, fft_size -@pytest.mark.skipif(not CUDA, reason="HCM Triton kernel requires CUDA") +@pytest.mark.gpu @pytest.mark.parametrize("B", [1, 2]) @pytest.mark.parametrize("L", [1024, 8192, 32768]) def test_hcm_complex_mul_matches_oracle(B: int, L: int) -> None: @@ -58,7 +56,7 @@ def test_hcm_complex_mul_matches_oracle(B: int, L: int) -> None: assert mean_diff < 1e-5, f"mean_diff={mean_diff:.2e}" -@pytest.mark.skipif(not CUDA, reason="HCM Triton kernel requires CUDA") +@pytest.mark.gpu def test_hcm_complex_mul_masks_ragged_tail() -> None: """ The kernel masks the flat (D, F) tail when D*F does not divide the tile. @@ -70,7 +68,7 @@ def test_hcm_complex_mul_masks_ragged_tail() -> None: assert (y_f - y_f_ref).abs().max().item() < 1e-4 -@pytest.mark.skipif(not CUDA, reason="HCM Triton kernel requires CUDA") +@pytest.mark.gpu @pytest.mark.parametrize("B", [1, 2]) @pytest.mark.parametrize("L", [1024, 8192, 32768]) def test_hcm_bias_residual_matches_oracle(B: int, L: int) -> None: @@ -95,7 +93,7 @@ def test_hcm_bias_residual_matches_oracle(B: int, L: int) -> None: assert mean_diff < 1e-5, f"mean_diff={mean_diff:.2e}" -@pytest.mark.skipif(not CUDA, reason="HCM Triton kernel requires CUDA") +@pytest.mark.gpu def test_hcm_bias_residual_masks_ragged_tile() -> None: """ The kernel masks (D, L) tiles that BLOCK_D x BLOCK_L does not divide evenly. diff --git a/test/test_hcs_adapter.py b/test/test_hcs_adapter.py index bdf9cbc..b196103 100644 --- a/test/test_hcs_adapter.py +++ b/test/test_hcs_adapter.py @@ -13,8 +13,6 @@ from vortex.ops.hcs_interface import hcs_conv -CUDA: bool = torch.cuda.is_available() - def _hcs_branch_ref(x1, x2, v, weight, bias, gated_bias, padding_mask): """ @@ -38,7 +36,7 @@ def _hcs_branch_ref(x1, x2, v, weight, bias, gated_bias, padding_mask): return x2 * z -@pytest.mark.skipif(not CUDA, reason="hcs_conv requires CUDA") +@pytest.mark.gpu @pytest.mark.parametrize("L", [1024, 8192]) @pytest.mark.parametrize("with_bias", [True, False]) def test_hcs_conv_matches_engine_branch(L: int, with_bias: bool) -> None: @@ -58,7 +56,7 @@ def test_hcs_conv_matches_engine_branch(L: int, with_bias: bool) -> None: assert (z - z_ref).abs().max().item() < 1e-3 -@pytest.mark.skipif(not CUDA, reason="hcs_conv requires CUDA") +@pytest.mark.gpu def test_hcs_conv_gated_bias() -> None: """ The gated_bias=True path applies the bias multiplicatively. @@ -76,7 +74,7 @@ def test_hcs_conv_gated_bias() -> None: assert (z - z_ref).abs().max().item() < 1e-3 -@pytest.mark.skipif(not CUDA, reason="hcs_conv requires CUDA") +@pytest.mark.gpu def test_hcs_conv_padding_mask() -> None: """ A padding_mask tensor zeros masked positions in the output. @@ -96,7 +94,7 @@ def test_hcs_conv_padding_mask() -> None: assert z[..., L // 2 :].abs().max().item() == 0.0 -@pytest.mark.skipif(not CUDA, reason="hcs_conv requires CUDA") +@pytest.mark.gpu def test_hcs_conv_bf16() -> None: """ hcs_conv matches the engine branch in bf16, the real inference dtype. diff --git a/test/test_hcs_e2e.py b/test/test_hcs_e2e.py index 0a07409..3538ce0 100644 --- a/test/test_hcs_e2e.py +++ b/test/test_hcs_e2e.py @@ -17,7 +17,6 @@ from vortex.model.engine import HyenaInferenceEngine -CUDA: bool = torch.cuda.is_available() _MODEL_ID: str = os.environ.get("VK_E2E_MODEL", "evo2_7b") _SEQ_LEN: int = 2048 @@ -27,7 +26,7 @@ def evo2_model(): """ Load the Evo2 model once for the module, or skip if unavailable. """ - if not CUDA: + if not torch.cuda.is_available(): pytest.skip("Evo2 e2e test requires CUDA") try: from evo2 import Evo2 @@ -77,7 +76,9 @@ def _logits(model, input_ids: torch.Tensor) -> torch.Tensor: return out.float() -@pytest.mark.skipif(not CUDA, reason="Evo2 e2e test requires CUDA") +@pytest.mark.gpu +@pytest.mark.e2e +@pytest.mark.slow def test_vk_hcs_e2e_matches_baseline(evo2_model) -> None: """ A full Evo2 forward is unchanged when use_hcs_kernel swaps in the kernel. diff --git a/test/test_hcs_engine.py b/test/test_hcs_engine.py index 291e579..8bf712a 100644 --- a/test/test_hcs_engine.py +++ b/test/test_hcs_engine.py @@ -13,8 +13,6 @@ from vortex.model.engine import HyenaInferenceEngine -CUDA: bool = torch.cuda.is_available() - # evo2_7b HCS cascade shapes: D=4096, hcs_filter_groups=256, fir_length=7. B, D, K, GROUPS = 1, 4096, 7, 256 DIMS: tuple[int, int, int, int, int] = (D, 32, D // 32, 16, GROUPS) @@ -38,7 +36,7 @@ def _hcs_inputs(L: int, dtype: torch.dtype): return u, weight, bias -@pytest.mark.skipif(not CUDA, reason="HCS kernel requires CUDA") +@pytest.mark.gpu @pytest.mark.parametrize("L", [1024, 8192]) def test_vk_hcs_on_matches_baseline_fp32(L: int) -> None: """ @@ -62,7 +60,7 @@ def test_vk_hcs_on_matches_baseline_fp32(L: int) -> None: assert (z_on - z_off).abs().max().item() < 1e-3 -@pytest.mark.skipif(not CUDA, reason="HCS kernel requires CUDA") +@pytest.mark.gpu def test_vk_hcs_on_matches_baseline_bf16() -> None: """ use_hcs_kernel on reproduces the stock HCS output in bf16, the inference dtype. @@ -79,7 +77,7 @@ def test_vk_hcs_on_matches_baseline_bf16() -> None: torch.testing.assert_close(z_on, z_off, rtol=2e-2, atol=2e-2) -@pytest.mark.skipif(not CUDA, reason="HCS kernel requires CUDA") +@pytest.mark.gpu def test_vk_hcs_off_by_default() -> None: """ A fresh HyenaInferenceEngine has use_hcs_kernel False, so parallel_fir @@ -103,7 +101,7 @@ def test_vk_hcs_off_by_default() -> None: assert (z_default - z_explicit).abs().max().item() == 0.0 -@pytest.mark.skipif(not CUDA, reason="HCS kernel requires CUDA") +@pytest.mark.gpu def test_vk_hcs_predicate_skips_non_hcs_calls() -> None: """ The branch matches only the gated short cascade -- not the featurizer diff --git a/test/test_hcs_kernel.py b/test/test_hcs_kernel.py index 9f67979..ec467e0 100644 --- a/test/test_hcs_kernel.py +++ b/test/test_hcs_kernel.py @@ -12,8 +12,6 @@ from vortex.ops.hcs_interface import hcs_depthwise_conv -CUDA: bool = torch.cuda.is_available() - def _conv1d_ref(u: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: """ @@ -26,7 +24,7 @@ def _conv1d_ref(u: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: )[..., :L] -@pytest.mark.skipif(not CUDA, reason="HCS Triton kernel requires CUDA") +@pytest.mark.gpu @pytest.mark.parametrize("L", [1024, 8192, 32768]) @pytest.mark.parametrize("fir_length", [3, 7]) def test_hcs_conv_matches_conv1d(L: int, fir_length: int) -> None: @@ -47,7 +45,7 @@ def test_hcs_conv_matches_conv1d(L: int, fir_length: int) -> None: assert mean_diff < 1e-4, f"mean_diff={mean_diff:.2e}" -@pytest.mark.skipif(not CUDA, reason="HCS Triton kernel requires CUDA") +@pytest.mark.gpu def test_hcs_conv_is_causal() -> None: """ Output position t depends only on inputs at or before t. @@ -71,7 +69,7 @@ def test_hcs_conv_is_causal() -> None: assert after > 0.0, "perturbation had no effect on later positions" -@pytest.mark.skipif(not CUDA, reason="HCS Triton kernel requires CUDA") +@pytest.mark.gpu @pytest.mark.parametrize("D", [16, 4096, 4100]) def test_hcs_conv_ragged_channels(D: int) -> None: """ From 9c54bb31c71ef1721355988fb19c5f95d17c4216 Mon Sep 17 00:00:00 2001 From: Alpha Khaw Date: Thu, 21 May 2026 01:10:01 +0800 Subject: [PATCH 16/16] docs: tighten kernel-module docstrings; collapse hcs contiguity guards --- vortex/ops/hcl_interface.py | 12 ++++-------- vortex/ops/hcm_interface.py | 16 ++++++---------- vortex/ops/hcs_interface.py | 34 ++++++++++++++-------------------- vortex/ops/triton_common.py | 9 ++++----- 4 files changed, 28 insertions(+), 43 deletions(-) diff --git a/vortex/ops/hcl_interface.py b/vortex/ops/hcl_interface.py index 6127d27..6a73aa9 100644 --- a/vortex/ops/hcl_interface.py +++ b/vortex/ops/hcl_interface.py @@ -242,14 +242,10 @@ def hcl_fft_conv( """ Fused HCL FFT-convolution epilogue. - Reproduces parallel_iir's long_fir_threshold-is-None branch in one path: - cuFFT keeps the three transforms, _hcm_complex_mul does the spectral - product X*H scaled by 1/fft_size, and _hcl_bias_residual_gate does the - post-conv (y + x1v*D[:, None]) * x2. - - The signal FFT uses rfft -- mathematically identical to fft on a real - input but produces only fft_size//2 + 1 bins, so cuFFT does ~half the - work the stock parallel_iir's `fft` + slice path does. + Reproduces parallel_iir's long_fir_threshold-is-None branch. cuFFT keeps + the three transforms; _hcm_complex_mul does the spectral product X*H + scaled by 1/fft_size; _hcl_bias_residual_gate does the post-conv + (y + x1v*D[:, None]) * x2. Uses rfft (half the bins of fft on real input). Args: h (torch.Tensor): The modal filter, shape (1, D, L). diff --git a/vortex/ops/hcm_interface.py b/vortex/ops/hcm_interface.py index 69a1afc..d2b72dc 100644 --- a/vortex/ops/hcm_interface.py +++ b/vortex/ops/hcm_interface.py @@ -246,15 +246,12 @@ def hcm_fft_conv( **kwargs, ) -> torch.Tensor: """ - Fused HCM FFT-convolution -- a drop-in for fftconv_func. + Fused HCM FFT-convolution -- drop-in for fftconv_func. - Reproduces fftconv_func's non-bidirectional inference path with the - elementwise glue fused into Triton kernels. cuFFT keeps the three - transforms; _hcm_complex_mul does the scaled spectral product and - _hcm_bias_residual does the skip-residual add. The signature mirrors - fftconv_func so the engine dispatch is a one-line swap; the trailing - args (dropout_mask, gelu, k_rev, bidirectional, print_activations, - layer_idx, **kwargs) exist only for that parity. + cuFFT keeps the three transforms; _hcm_complex_mul does stage 3 (scaled + spectral product), _hcm_bias_residual does stage 5 (skip-residual add). + Trailing kwargs exist only for signature parity with fftconv_func so the + engine dispatch is a one-line swap. Args: u (torch.Tensor): Input activations, shape (B, D, L). @@ -265,8 +262,7 @@ def hcm_fft_conv( torch.Tensor: y + u * D[:, None], shape (B, D, L), u's dtype. Raises: - NotImplementedError: If bidirectional is True or k_rev is set; - unsupported paths the HCM dispatch never hits. + NotImplementedError: If bidirectional is True or k_rev is set. """ if bidirectional or k_rev is not None: raise NotImplementedError( diff --git a/vortex/ops/hcs_interface.py b/vortex/ops/hcs_interface.py index e6c2f85..d3ce4da 100644 --- a/vortex/ops/hcs_interface.py +++ b/vortex/ops/hcs_interface.py @@ -1,13 +1,11 @@ """ HCS -- Hyena Cascade Short. -A from-scratch Triton depthwise causal 1D convolution for the short-filter -(fir_length < 128) gated branch of HyenaInferenceEngine.parallel_fir. - -The convolution is the only time-mixing op in an HCS layer: a depthwise -filter of fir_length taps (7 in evo2_7b) applied per channel. This module -provides the @triton.jit kernel and a thin Python launcher; the hcs_conv -adapter that wires it behind the use_hcs_kernel config flag is added alongside. +Triton depthwise causal 1D conv for the short-filter (fir_length < 128) +gated branch of HyenaInferenceEngine.parallel_fir. A depthwise filter of +fir_length taps (7 in evo2_7b) applied per channel; the only time-mixing +op in an HCS layer. Exposes the @triton.jit kernel and the hcs_conv +adapter wired behind use_hcs_kernel. """ import torch @@ -89,10 +87,8 @@ def hcs_depthwise_conv(u: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: Convolved output, shape (B, D, L), same dtype as u. """ - if not u.is_contiguous(): - u = u.contiguous() - if not weight.is_contiguous(): - weight = weight.contiguous() + u = u.contiguous() + weight = weight.contiguous() if u.dim() != 3 or weight.dim() != 3: raise ValueError(f"expected 3-D u and weight, got {u.shape} and {weight.shape}") @@ -131,11 +127,9 @@ def hcs_conv( """ Fully-gated HCS short conv: z = x2 * (conv(x1 * v, weight) + bias). - Drop-in replacement for the gated fir_length < 128 branch of - HyenaInferenceEngine.parallel_fir. It reproduces that branch exactly: - the depthwise conv runs in fp32 for numerical parity with the F.conv1d - path, the result is cast back to the activation dtype, bias-added, - masked, then closed with the post-gate multiply by x2. + Drop-in for the gated fir_length < 128 branch of + HyenaInferenceEngine.parallel_fir. Conv runs in fp32 for parity with + F.conv1d, then casts back to x1.dtype. Args: x1 (torch.Tensor): Pre-gate "key" stream, shape (B, D, L). @@ -143,10 +137,10 @@ def hcs_conv( v (torch.Tensor): "Value" stream, shape (B, D, L). weight (torch.Tensor): Depthwise filter, shape (D, 1, fir_length). bias (torch.Tensor | None): Per-channel skip-gain, shape (D,). - gated_bias (bool): If True the bias is applied multiplicatively - (bias * x1 * v); HCS uses additive bias (False). - padding_mask (torch.Tensor | None): If a tensor, zeros masked - positions after the conv, shape (B, L). + gated_bias (bool): If True, bias is applied multiplicatively + (bias * x1 * v); HCS uses additive (False). + padding_mask (torch.Tensor | None): If set, zeros masked positions + after the conv, shape (B, L). Returns: torch.Tensor: Gated HCS output, shape (B, D, L), x1's dtype. diff --git a/vortex/ops/triton_common.py b/vortex/ops/triton_common.py index a21d343..2cc3f58 100644 --- a/vortex/ops/triton_common.py +++ b/vortex/ops/triton_common.py @@ -1,11 +1,10 @@ """ Shared autotune configs and grid helpers for the HC{S,M,L} Triton kernels. -Pieces here are kept intentionally small. Only put a config or helper here -when it is genuinely identical across kernels -- any kernel that needs a -different tile or grid declares its own. Each @triton.autotune-decorated -kernel keeps its own benchmark cache, so sharing the config list is a -syntactic convenience, not a performance contract. +BDL_TILE_CONFIGS is the (BLOCK_D, BLOCK_L) sweep for memory-bound elementwise +kernels over a (D, L) plane. Each @triton.autotune-decorated kernel keeps +its own benchmark cache, so sharing the config list is a syntactic +convenience, not a performance contract. """ from collections.abc import Callable