Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,28 @@ authors = [{ name = "Michael Poli" }]
requires-python = ">=3.10"
license-files = ["LICENSE"]
dependencies = [
"torch",
"numpy",
"einops==0.8.1",
"packaging",
"rich",
"tqdm",
"PyYAML",
"torch",
"numpy",
"einops==0.8.1",
"packaging",
"rich",
"tqdm",
"PyYAML",
]

[tool.setuptools.packages.find]
where = ["."]
include = ["vortex*"]

[tool.pytest.ini_options]
testpaths = ["test"]
markers = [
"gpu: requires a CUDA device",
"e2e: end-to-end test that loads an Evo2 checkpoint (slow, large GPU memory)",
"slow: takes >10s",
]
filterwarnings = ["ignore::DeprecationWarning:torch\\..*"]

[tool.black]
line-length = 119

Expand Down
17 changes: 17 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest
import torch


def pytest_collection_modifyitems(
config: pytest.Config, items: list[pytest.Item]
) -> None:
"""
Auto-skip tests marked @pytest.mark.gpu when no CUDA device is present.
"""
_ = config
if torch.cuda.is_available():
return
skip_gpu = pytest.mark.skip(reason="requires CUDA device")
for item in items:
if "gpu" in item.keywords:
item.add_marker(skip_gpu)
56 changes: 56 additions & 0 deletions test/test_hcl_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""
Tests for hcl_fft_conv -- the fused FFT-conv epilogue for parallel_iir's
long_fir_threshold-is-None (HCL) branch.

The oracle is the stock branch: rfft(h)/fft_size, fft(x1v), X*H, irfft, then
the post-conv (y + x1v*D[:, None]) * x2. hcl_fft_conv reproduces it with the
_hcm_complex_mul and _hcl_bias_residual_gate kernels.
"""

import pytest
import torch

from vortex.ops.hcl_interface import hcl_fft_conv


def _hcl_branch_ref(
h: torch.Tensor,
x1v: torch.Tensor,
x2: torch.Tensor,
D: torch.Tensor,
L: int,
fft_size: int,
) -> torch.Tensor:
"""
Pure-torch reference for parallel_iir's HCL FFT-conv branch + post-conv.
"""
H = torch.fft.rfft(h.to(torch.float32), n=fft_size) / fft_size
X = torch.fft.fft(x1v.to(torch.float32), n=fft_size)[..., : H.shape[-1]]
y = torch.fft.irfft(X * H, n=fft_size, norm="forward")[..., :L]
y = y.to(x1v.dtype)
return (y + x1v * D.unsqueeze(-1)) * x2


@pytest.mark.gpu
@pytest.mark.parametrize("L", [2048, 8192, 32768])
def test_hcl_fft_conv_matches_branch(L: int) -> None:
"""
hcl_fft_conv reproduces the stock parallel_iir HCL branch in fp32.
"""
torch.manual_seed(0)
D = 4096
fft_size = 2 * L
h = torch.randn(1, D, L, dtype=torch.float32, device="cuda")
x1v = torch.randn(1, D, L, dtype=torch.float32, device="cuda")
x2 = torch.randn(1, D, L, dtype=torch.float32, device="cuda")
bias = torch.randn(D, dtype=torch.float32, device="cuda")

y = hcl_fft_conv(h, x1v, x2, bias, L, fft_size)
y_ref = _hcl_branch_ref(h, x1v, x2, bias, L, fft_size)

assert y.shape == y_ref.shape == (1, D, L)
assert y.dtype == y_ref.dtype
max_diff = (y - y_ref).abs().max().item()
mean_diff = (y - y_ref).abs().mean().item()
assert max_diff < 1e-2, f"max_diff={max_diff:.2e}"
assert mean_diff < 1e-3, f"mean_diff={mean_diff:.2e}"
141 changes: 141 additions & 0 deletions test/test_hcl_e2e.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
"""
End-to-end tests for the use_hcl_kernel flag inside a real Evo2 model.

Loads an Evo2 checkpoint once and runs the same forward with the HCL kernel
off (stock compute_filter + the parallel_iir FFT branch) and on (the tiled
filter build + the fused FFT-conv epilogue). The kernel is numerically
equivalent, not bit-exact, so the behavioural invariant is tested: the model
predicts the same tokens and the logit vectors stay near-parallel.

test_hcl_unlocks_131k is the headline -- the stock compute_filter OOMs at
L=131k on its (D, state_size, L) intermediate; the tiled kernel removes it.
It needs an ~80GB GPU (the sign-off pod) and skips on a smaller card.

These load a large checkpoint -- slow relative to the kernel tests; run them
deliberately, not in the fast loop.
"""

import os

import pytest
import torch

from vortex.model.engine import HyenaInferenceEngine

_MODEL_ID: str = os.environ.get("VK_E2E_MODEL", "evo2_7b")
_SEQ_LEN: int = 2048


@pytest.fixture(scope="module")
def evo2_model():
"""
Load the Evo2 model once for the module, or skip if unavailable.
"""
if not torch.cuda.is_available():
pytest.skip("Evo2 e2e test requires CUDA")
try:
from evo2 import Evo2
except ImportError as exc: # pragma: no cover - depends on optional dep
pytest.skip(f"evo2 not installed: {exc}")
try:
return Evo2(_MODEL_ID)
except Exception as exc: # noqa: BLE001 - any load failure -> skip, not fail
pytest.skip(f"could not load Evo2({_MODEL_ID!r}): {exc}")


def _set_hcl_kernel(model, enabled: bool) -> int:
"""
Flip use_hcl_kernel on every HyenaInferenceEngine reachable from model.

Args:
model: A loaded Evo2 model.
enabled (bool): Target value for use_hcl_kernel.

Returns:
The number of HyenaInferenceEngine instances touched.
"""
root = getattr(model, "model", model)
touched = 0
for module in root.modules():
engine = getattr(module, "engine", None)
if isinstance(engine, HyenaInferenceEngine):
engine.use_hcl_kernel = enabled
touched += 1
return touched


def _logits(model, input_ids: torch.Tensor) -> torch.Tensor:
"""
Run a forward pass and return the logits tensor as fp32.
"""
with torch.no_grad():
out = model(input_ids)
while isinstance(out, (tuple, list)):
out = out[0]
return out.float()


@pytest.mark.gpu
@pytest.mark.e2e
@pytest.mark.slow
def test_vk_hcl_e2e_matches_baseline(evo2_model) -> None:
"""
A full Evo2 forward is behaviourally unchanged when use_hcl_kernel swaps
in the tiled filter build and the fused FFT-conv epilogue.

The kernel is numerically equivalent to the stock path but not bit-exact,
so the test asserts prediction agreement (argmax + cosine), not an
absolute logit bound -- the same rationale as the HCM e2e.
"""
torch.manual_seed(0)
input_ids = torch.randint(1, 5, (1, _SEQ_LEN), dtype=torch.int, device="cuda:0")

try:
touched = _set_hcl_kernel(evo2_model, False)
assert touched > 0, "no HyenaInferenceEngine found in the Evo2 model"
logits_off = _logits(evo2_model, input_ids)

_set_hcl_kernel(evo2_model, True)
logits_on = _logits(evo2_model, input_ids)
finally:
_set_hcl_kernel(evo2_model, False)

assert logits_on.shape == logits_off.shape

agreement = (logits_on.argmax(-1) == logits_off.argmax(-1)).float().mean().item()
cosine = torch.nn.functional.cosine_similarity(
logits_on.flatten(), logits_off.flatten(), dim=0
).item()
assert agreement >= 0.99, (
f"use_hcl_kernel changed {(1 - agreement) * 100:.2f}% of token predictions"
)
assert cosine >= 0.9999, f"use_hcl_kernel logits diverged: cosine={cosine:.6f}"


@pytest.mark.gpu
@pytest.mark.e2e
@pytest.mark.slow
def test_hcl_unlocks_131k(evo2_model) -> None:
"""
A full evo2_7b forward at L=131072 completes with use_hcl_kernel on.

The stock compute_filter OOMs at this length on its (D, state_size, L)
fp32 intermediate (34 GiB at D=4096); the tiled kernel removes it. Needs
an ~80GB GPU (the H100 sign-off pod) -- on a smaller card this skips:
even with the kernel, the rest of the 131k forward will not fit.
"""
torch.manual_seed(0)
input_ids = torch.randint(1, 5, (1, 131072), dtype=torch.int, device="cuda:0")

try:
touched = _set_hcl_kernel(evo2_model, True)
assert touched > 0, "no HyenaInferenceEngine found in the Evo2 model"
logits = _logits(evo2_model, input_ids)
except torch.cuda.OutOfMemoryError:
torch.cuda.empty_cache()
pytest.skip("L=131072 needs ~80GB; run on the H100 sign-off pod")
finally:
_set_hcl_kernel(evo2_model, False)

assert logits.shape[1] == 131072
assert torch.isfinite(logits).all()
98 changes: 98 additions & 0 deletions test/test_hcl_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""
Wiring tests for the use_hcl_kernel branch in HyenaInferenceEngine.parallel_iir.

These exercise the long_fir_threshold-is-None dispatch by calling parallel_iir
directly. With the kernel enabled, the HCL FFT-conv must reproduce the stock
branch; the branch must not fire when long_fir_threshold is set.
"""

import pytest
import torch

from vortex.model.engine import HyenaInferenceEngine

# evo2_7b HCL shapes: D=4096, state_size=16.
B, D, S = 1, 4096, 16
DIMS: tuple[int, int, int, int, int] = (D, 32, D // 32, 16, 256)


def _hcl_inputs(L: int, dtype: torch.dtype):
"""
Build (z_pre, h, bias, poles, residues, t) for an HCL parallel_iir call.
"""
z_pre = torch.randn(B, 3 * D, L, dtype=dtype, device="cuda")
h = torch.randn(1, D, L, dtype=torch.float32, device="cuda")
bias = torch.randn(D, dtype=dtype, device="cuda")
poles = torch.randn(D, S, 1, dtype=torch.float32, device="cuda")
residues = torch.randn(D, S, dtype=torch.float32, device="cuda")
t = torch.arange(L, device="cuda")
return z_pre, h, bias, poles, residues, t


def _call(engine, z_pre, h, bias, L, poles, residues, t, **kw):
"""
Invoke parallel_iir on the HCL FFT path (long_fir_threshold None).
"""
return engine.parallel_iir(
z_pre,
h,
bias,
L,
poles=poles,
residues=residues,
t=t,
dims=DIMS,
layer_idx=0,
long_fir_threshold=kw.pop("long_fir_threshold", None),
**kw,
)


@pytest.mark.gpu
@pytest.mark.parametrize("L", [2048, 8192])
def test_vk_hcl_on_matches_baseline_fp32(L: int) -> None:
"""
use_hcl_kernel on reproduces the stock parallel_iir HCL output in fp32.
"""
torch.manual_seed(0)
engine = HyenaInferenceEngine(layer_idx=0)
z_pre, h, bias, poles, residues, t = _hcl_inputs(L, torch.float32)

y_off = _call(engine, z_pre, h, bias, L, poles, residues, t)
engine.use_hcl_kernel = True
y_on = _call(engine, z_pre, h, bias, L, poles, residues, t)

assert y_on.shape == y_off.shape == (B, L, D)
assert (y_on - y_off).abs().max().item() < 1e-2


@pytest.mark.gpu
def test_vk_hcl_off_by_default() -> None:
"""
A fresh HyenaInferenceEngine has use_hcl_kernel False -- the stock path.
"""
torch.manual_seed(0)
engine = HyenaInferenceEngine(layer_idx=0)
assert engine.use_hcl_kernel is False

z_pre, h, bias, poles, residues, t = _hcl_inputs(2048, torch.float32)
y_default = _call(engine, z_pre, h, bias, 2048, poles, residues, t)

explicit_off = HyenaInferenceEngine(layer_idx=0, use_hcl_kernel=False)
y_explicit = _call(explicit_off, z_pre, h, bias, 2048, poles, residues, t)

assert (y_default - y_explicit).abs().max().item() == 0.0


@pytest.mark.gpu
def test_vk_hcl_predicate_skips_long_fir() -> None:
"""
The branch matches only long_fir_threshold is None -- a set threshold
routes through the stock depthwise-conv path.
"""
torch.manual_seed(0)
engine = HyenaInferenceEngine(layer_idx=0, use_hcl_kernel=True)
z_pre, h, bias, poles, residues, t = _hcl_inputs(2048, torch.float32)

y = _call(engine, z_pre, h, bias, 2048, poles, residues, t, long_fir_threshold=128)
assert y.shape == (B, 2048, D)
Loading