Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 229 additions & 0 deletions tests/end_to_end/test_sanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,6 +936,24 @@ def test_oob_with_fake_tensor(_isolate_virtual_memory):
fake_tensor_oob_kernel[(1,)](x, out, N=8)


@triton_viz.trace(client=SymbolicSanitizer())
@triton.jit
def cast_scalar_kernel(x_ptr, out_ptr, eps: tl.constexpr, N: tl.constexpr):
    # Regression kernel: `eps` arrives as a constexpr float, and calling
    # `.to(dtype)` on it must work under the patched interpreter (a plain
    # Python float has no `.to`, so the patch has to supply one).
    offs = tl.arange(0, N)
    x = tl.load(x_ptr + offs)
    dtype = x.dtype
    eps_cast = eps.to(dtype)
    y = x + eps_cast
    tl.store(out_ptr + offs, y)


def test_float_no_attr_to():
    """constexpr float .to(dtype) must work in interpreter mode."""
    x = torch.randn(8, device="cpu")
    out = torch.empty(8, device="cpu")
    cast_scalar_kernel[(1,)](x, out, eps=1e-6, N=8)
    # Previously this only checked that the kernel ran without raising.
    # Also verify the stored values so a silently no-op cast (or a cast
    # that corrupts eps) is caught, not just an AttributeError.
    assert torch.allclose(out, x + 1e-6)


@triton_viz.trace(client=SymbolicSanitizer())
@triton.jit
def block_ptr_sum_kernel(
Expand Down Expand Up @@ -1259,9 +1277,220 @@ def test_reduce_on_batched_dot_result():
), f"Expected no OOB records, got {len(reduce_dot_sanitizer.records)}"


@triton_viz.trace(client="tracer")
@triton.jit
def cast_truncate_kernel(out_ptr, v: tl.constexpr, N: tl.constexpr):
offs = tl.arange(0, N)
result = v.to(tl.int32)
tl.store(out_ptr + offs, result)


def test_constexpr_to_real_cast():
    """Casting a constexpr float with .to(tl.int32) must truncate, not no-op."""
    dest = torch.empty(1, dtype=torch.int32)
    cast_truncate_kernel[(1,)](dest, v=1.9, N=1)
    assert dest.item() == 1


@triton_viz.trace(client="tracer")
@triton.jit
def cast_bitcast_kernel(out_ptr, v: tl.constexpr, N: tl.constexpr):
offs = tl.arange(0, N)
result = v.to(tl.float32, bitcast=True)
tl.store(out_ptr + offs, result)


def test_constexpr_to_bitcast():
    """Bitcasting a constexpr int to float32 must reinterpret the raw bits."""
    import struct

    expected = struct.unpack("f", struct.pack("i", 42))[0]
    result = torch.empty(1, dtype=torch.float32)
    cast_bitcast_kernel[(1,)](result, v=42, N=1)
    assert result.item() == pytest.approx(expected)


@triton_viz.trace(client="tracer")
@triton.jit
def dtype_meta_kernel(out_ptr, DTYPE: tl.constexpr, N: tl.constexpr):
offs = tl.arange(0, N)
x = tl.full((N,), value=1.0, dtype=DTYPE)
tl.store(out_ptr + offs, x)


def test_constexpr_dtype_meta_param():
    """A dtype forwarded as a constexpr meta-parameter must survive wrapping."""
    buf = torch.empty(8, dtype=torch.int32)
    dtype_meta_kernel[(1,)](buf, DTYPE=tl.int32, N=8)
    assert bool((buf == 1).all())


@triton_viz.trace(client="tracer")
@triton.jit
def tuple_meta_kernel(out_ptr, SHAPE: tl.constexpr):
offs = tl.arange(0, SHAPE[0])
tl.store(out_ptr + offs, offs)


def test_constexpr_tuple_meta_param():
    """A tuple meta-parameter constexpr must support __getitem__."""
    buf = torch.empty(4, dtype=torch.int32)
    tuple_meta_kernel[(1,)](buf, SHAPE=(4,))
    expected = torch.arange(4, dtype=torch.int32)
    assert bool((buf == expected).all())


@triton_viz.trace(client="tracer")
@triton.jit
def grid_lambda_kernel(out_ptr, N: tl.constexpr, BLOCK: tl.constexpr):
pid = tl.program_id(0)
offs = pid * BLOCK + tl.arange(0, BLOCK)
mask = offs < N
tl.store(out_ptr + offs, offs, mask=mask)


def test_constexpr_grid_lambda():
    """Grid lambda receives raw host values, not tl.constexpr wrappers."""
    total = 256
    buf = torch.empty(total, dtype=torch.int32)

    def grid(META):
        # Upstream Triton hands plain Python values to grid lambdas, so an
        # isinstance check must behave exactly as it would in real Triton.
        block = META["BLOCK"]
        assert isinstance(
            block, int
        ), f"grid lambda should see raw int, got {type(block)}"
        return (triton.cdiv(total, block),)

    grid_lambda_kernel[grid](buf, N=total, BLOCK=128)
    assert bool((buf == torch.arange(total, dtype=torch.int32)).all())


@triton_viz.trace(client="tracer")
@triton.jit
def bool_constexpr_kernel(out_ptr, flag: tl.constexpr, N: tl.constexpr):
offs = tl.arange(0, N)
val = flag.to(tl.int1)
tl.store(out_ptr + offs, val)


def test_constexpr_bool_cast():
    """Casting a bool constexpr with .to(tl.int1) must succeed."""
    buf = torch.empty(1, dtype=torch.int32)
    bool_constexpr_kernel[(1,)](buf, flag=True, N=1)
    assert buf.item() == 1


def test_reinterpret_tensor_wrapper():
    """triton.reinterpret() produces a TensorWrapper; the sanitizer must accept it."""
    n = 64
    src = torch.ones(n, dtype=torch.float16, device="cpu")
    dst = torch.empty(n, dtype=torch.float16, device="cpu")
    wrapped = triton.reinterpret(src, tl.float16)
    copy_kernel[(1,)](wrapped, dst, n, BLOCK=64)


# ======== constexpr.to() Regression Tests ===========


@triton_viz.trace(client="tracer")
@triton.jit
def bitcast_float_to_int_kernel(out_ptr, v: tl.constexpr, N: tl.constexpr):
offs = tl.arange(0, N)
result = v.to(tl.int32, bitcast=True)
tl.store(out_ptr + offs, result)


def test_constexpr_bitcast_float_src():
    """A float constexpr bitcast must start from float32 bits, never float64."""
    import struct

    expected = struct.unpack("i", struct.pack("f", 1.0))[0]  # 0x3f800000
    buf = torch.empty(1, dtype=torch.int32)
    bitcast_float_to_int_kernel[(1,)](buf, v=1.0, N=1)
    assert buf.item() == expected


@triton_viz.trace(client="tracer")
@triton.jit
def fp_downcast_kernel(
out_ptr, v: tl.constexpr, rounding: tl.constexpr, N: tl.constexpr
):
offs = tl.arange(0, N)
result = v.to(tl.float16, fp_downcast_rounding=rounding)
tl.store(out_ptr + offs, result)


def test_constexpr_fp_downcast_rtz():
    """fp_downcast_rounding='rtz' must truncate toward zero, not round to nearest."""
    # 1.00146484375 (fp32) sits exactly halfway between the fp16 neighbours
    # 1.0009765625 and 1.001953125: rtz truncates down, rtne rounds up.
    midpoint = 1.00146484375
    truncated = torch.empty(1, dtype=torch.float16)
    nearest = torch.empty(1, dtype=torch.float16)
    fp_downcast_kernel[(1,)](truncated, v=midpoint, rounding="rtz", N=1)
    fp_downcast_kernel[(1,)](nearest, v=midpoint, rounding="rtne", N=1)
    assert float(truncated.item()) == pytest.approx(1.0009765625)
    assert float(nearest.item()) == pytest.approx(1.001953125)


@triton_viz.trace(client="tracer")
@triton.jit
def bool_mask_kernel(out_ptr, flag: tl.constexpr, N: tl.constexpr):
offs = tl.arange(0, N)
cond = flag.to(tl.int1)
result = tl.where(
cond, tl.full((N,), 42, dtype=tl.int32), tl.full((N,), -1, dtype=tl.int32)
)
tl.store(out_ptr + offs, result)


def test_constexpr_bool_as_mask():
    """A bool constexpr cast via .to(tl.int1) must drive tl.where as a predicate."""
    buf = torch.empty(4, dtype=torch.int32)
    bool_mask_kernel[(1,)](buf, flag=True, N=4)
    assert bool((buf == 42).all())
    bool_mask_kernel[(1,)](buf, flag=False, N=4)
    assert bool((buf == -1).all())


def test_constexpr_bitcast_mismatched_size_raises():
    """Bitcasting between types of different bit widths must raise ValueError."""
    import types

    from triton_viz.core.patch import _constexpr_to

    # float(1.0) is treated as float32 (32 bits); tl.int64 is 64 bits wide.
    fake_constexpr = types.SimpleNamespace(value=1.0)
    with pytest.raises(ValueError, match="Cannot bitcast"):
        _constexpr_to(fake_constexpr, tl.int64, bitcast=True)


def test_constexpr_to_unsupported_dtype_raises():
    """Casting constexpr to an exotic dtype without numpy equivalent must raise, not KeyError."""
    import types

    from triton_viz.core.patch import _constexpr_to

    fake_constexpr = types.SimpleNamespace(value=1.0)
    for name in ("float8e4nv", "float8e5"):
        fp8_dtype = getattr(tl, name, None)
        if fp8_dtype is None:
            # This Triton build does not expose the fp8 variant; skip it.
            continue
        with pytest.raises(TypeError, match="not supported"):
            _constexpr_to(fake_constexpr, fp8_dtype)


# ======== Unsigned Division Regression Test ===========


@triton_viz.trace(client="tracer")
@triton.jit
def unsigned_div_kernel(out_ptr, DENOM: tl.constexpr):
numerator = tl.full((1,), 8, tl.uint32)
denom = DENOM.to(tl.uint32)
out = numerator // denom
tl.store(out_ptr + tl.arange(0, 1), out)


def test_constexpr_unsigned_div():
    """uint32 floor division with a constexpr denominator must stay unsigned."""
    buf = torch.empty(1, dtype=torch.int32)
    unsigned_div_kernel[(1,)](buf, DENOM=1)
    assert buf.item() == 8
64 changes: 64 additions & 0 deletions tests/unit/test_constexpr_to.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import types

import pytest
import numpy as np
import triton.language as tl

from triton_viz.core.patch import _constexpr_to, _src_np_dtype


@pytest.mark.parametrize(
    "value,triton_dtype,expected_np,expected_val",
    [
        (1.0, tl.float64, np.float64, 1.0),
        (1, tl.uint8, np.uint8, 1),
        (-7, tl.int16, np.int16, -7),
        (1, tl.uint32, np.uint32, 1),
        # Narrowing / signedness-change casts — must wrap, not reject
        (300, tl.uint8, np.uint8, 300 % 256),  # 44
        (-1, tl.uint32, np.uint32, 2**32 - 1),  # 4294967295
        (2**31, tl.int32, np.int32, -(2**31)),  # signed wrap
        (255, tl.int8, np.int8, -1),  # unsigned → signed wrap
    ],
)
def test_constexpr_to_preserves_dtype(value, triton_dtype, expected_np, expected_val):
    """_constexpr_to must yield a tensor of the requested dtype holding the wrapped value."""
    fake_constexpr = types.SimpleNamespace(value=value)
    result = _constexpr_to(fake_constexpr, triton_dtype)
    assert isinstance(result, tl.core.tensor)
    assert result.dtype == triton_dtype
    payload = result.handle.data
    assert payload.dtype == np.dtype(expected_np)
    assert payload.item() == expected_val


def test_constexpr_bitcast_preserves_dtype():
    """A bitcast through _constexpr_to must yield a tensor of the target dtype."""
    fake_constexpr = types.SimpleNamespace(value=42)
    result = _constexpr_to(fake_constexpr, tl.float32, bitcast=True)
    assert isinstance(result, tl.core.tensor)
    assert result.dtype == tl.float32


@pytest.mark.parametrize("bad_mode", ["RTZ", "foo", ""])
def test_constexpr_to_invalid_rounding_mode_raises(bad_mode):
"""Invalid fp_downcast_rounding values must raise ValueError, not silently fall back."""
mock_self = types.SimpleNamespace(value=1.0)
with pytest.raises(ValueError, match="fp_downcast_rounding must be one of"):
_constexpr_to(mock_self, tl.float16, fp_downcast_rounding=bad_mode)


def test_constexpr_to_overflow_raises():
    """Integer literals outside [-2**63, 2**64) must raise OverflowError."""
    too_small = -(2**63) - 1  # one below the int64 minimum
    fake_constexpr = types.SimpleNamespace(value=too_small)
    with pytest.raises(OverflowError, match="outside the representable range"):
        _constexpr_to(fake_constexpr, tl.int64)


def test_constexpr_bitcast_bool_to_int8_raises():
    """Bitcast bool(int1) -> int8 must raise: primitive_bitwidth 1 != 8."""
    fake_constexpr = types.SimpleNamespace(value=True)
    with pytest.raises(ValueError, match="Cannot bitcast"):
        _constexpr_to(fake_constexpr, tl.int8, bitcast=True)


def test_src_np_dtype_max_uint64():
    """Values in [2**63, 2**64) must map to uint64 rather than overflow."""
    max_uint64 = 2**64 - 1
    assert _src_np_dtype(max_uint64) == np.dtype(np.uint64)
28 changes: 27 additions & 1 deletion tests/unit/test_patch_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import triton.language as tl

from triton_viz.core import patch as patch_mod
from triton_viz.core.patch import _triton_snapshot_scope
from triton_viz.core.patch import _triton_snapshot_scope, patch_lang, unpatch_lang


def _dummy_kernel():
Expand Down Expand Up @@ -72,3 +72,29 @@ def _getmembers(obj):
scope.restore()

assert getattr(descriptor, attr) is original


def test_constexpr_patch_lifecycle():
    """patch_lang must add .to/__getattr__ to constexpr; unpatch_lang must restore."""
    # Record the pre-patch state so an exact round-trip can be verified.
    to_existed = hasattr(tl.constexpr, "to")
    getattr_existed = hasattr(tl.constexpr, "__getattr__")
    saved_to = getattr(tl.constexpr, "to", None)
    saved_getattr = getattr(tl.constexpr, "__getattr__", None)

    patch_lang(_dummy_kernel, "triton")
    try:
        for attr in ("to", "__getattr__"):
            assert hasattr(tl.constexpr, attr)
            assert callable(getattr(tl.constexpr, attr))
    finally:
        unpatch_lang("triton")

    # The exact pre-patch attributes must be back in place.
    assert hasattr(tl.constexpr, "to") == to_existed
    assert hasattr(tl.constexpr, "__getattr__") == getattr_existed
    if to_existed:
        assert getattr(tl.constexpr, "to") is saved_to
    if getattr_existed:
        assert getattr(tl.constexpr, "__getattr__") is saved_getattr
13 changes: 13 additions & 0 deletions tests/unit/test_symbolic_bool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import triton.language as tl

from triton_viz.clients.symbolic_engine import SymbolicExpr


def test_bool_inferred_as_int1():
    """Plain Python bools must be typed as tl.int1 by the symbolic engine."""
    for literal in (True, False):
        assert SymbolicExpr.from_value(literal).dtype is tl.int1


def test_constexpr_bool_inferred_as_int1():
    """Bools wrapped in tl.constexpr must also be typed as tl.int1."""
    for literal in (True, False):
        assert SymbolicExpr.from_value(tl.constexpr(literal)).dtype is tl.int1
12 changes: 9 additions & 3 deletions triton_viz/clients/symbolic_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def _node_label_core(self) -> str:
tl.float32,
tl.float64,
)
builtin_scala_types: ClassVar[tuple[type, ...]] = (int, float)
builtin_scala_types: ClassVar[tuple[type, ...]] = (bool, int, float)
tuple_types: ClassVar[tuple[type, ...]] = (tl.core.tuple, tuple, list)

@staticmethod
Expand Down Expand Up @@ -458,8 +458,12 @@ def _infer_literal_dtype(var: Any) -> tl.core.dtype | tl.pointer_type:
f"All elements in the tuple must have the same dtype, but found {first_dtype} and {SymbolicExpr.from_value(v).dtype}"
)
return first_dtype
if isinstance(var, SymbolicExpr.builtin_scala_types):
return tl.int32 if isinstance(var, int) else tl.float32
if isinstance(var, bool):
return tl.int1
if isinstance(var, int):
return tl.int32
if isinstance(var, float):
return tl.float32
raise ValueError(f"Unsupported type: {type(var)}")

# Stored on the class and may be accessed through either the class or an instance;
Expand All @@ -475,6 +479,8 @@ def set_loop_ctx_provider(cls, fn: Callable[..., LoopContext | None]) -> None:
@classmethod
def from_value(cls, var: Any) -> SymbolicExpr:
"""Create a SymbolicExpr from a Python value."""
if isinstance(var, tl.constexpr): # unwrap constexpr to raw value
var = var.value
if isinstance(var, tl.core.tensor): # if a triton tensor
var = var.handle # get its handle

Expand Down
Loading
Loading