diff --git a/tests/end_to_end/test_sanitizer.py b/tests/end_to_end/test_sanitizer.py index 41474e8b..2af88e87 100644 --- a/tests/end_to_end/test_sanitizer.py +++ b/tests/end_to_end/test_sanitizer.py @@ -936,6 +936,24 @@ def test_oob_with_fake_tensor(_isolate_virtual_memory): fake_tensor_oob_kernel[(1,)](x, out, N=8) +@triton_viz.trace(client=SymbolicSanitizer()) +@triton.jit +def cast_scalar_kernel(x_ptr, out_ptr, eps: tl.constexpr, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(x_ptr + offs) + dtype = x.dtype + eps_cast = eps.to(dtype) + y = x + eps_cast + tl.store(out_ptr + offs, y) + + +def test_float_no_attr_to(): + """constexpr float .to(dtype) must work in interpreter mode.""" + x = torch.randn(8, device="cpu") + out = torch.empty(8, device="cpu") + cast_scalar_kernel[(1,)](x, out, eps=1e-6, N=8) + + @triton_viz.trace(client=SymbolicSanitizer()) @triton.jit def block_ptr_sum_kernel( @@ -1259,9 +1277,220 @@ def test_reduce_on_batched_dot_result(): ), f"Expected no OOB records, got {len(reduce_dot_sanitizer.records)}" +@triton_viz.trace(client="tracer") +@triton.jit +def cast_truncate_kernel(out_ptr, v: tl.constexpr, N: tl.constexpr): + offs = tl.arange(0, N) + result = v.to(tl.int32) + tl.store(out_ptr + offs, result) + + +def test_constexpr_to_real_cast(): + """constexpr.to(tl.int32) must truncate, not no-op.""" + out = torch.empty(1, dtype=torch.int32) + cast_truncate_kernel[(1,)](out, v=1.9, N=1) + assert out.item() == 1 + + +@triton_viz.trace(client="tracer") +@triton.jit +def cast_bitcast_kernel(out_ptr, v: tl.constexpr, N: tl.constexpr): + offs = tl.arange(0, N) + result = v.to(tl.float32, bitcast=True) + tl.store(out_ptr + offs, result) + + +def test_constexpr_to_bitcast(): + """constexpr.to(tl.float32, bitcast=True) must reinterpret bits.""" + import struct + + out = torch.empty(1, dtype=torch.float32) + cast_bitcast_kernel[(1,)](out, v=42, N=1) + expected = struct.unpack("f", struct.pack("i", 42))[0] + assert out.item() == pytest.approx(expected) + + +@triton_viz.trace(client="tracer") +@triton.jit +def dtype_meta_kernel(out_ptr, DTYPE: tl.constexpr, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.full((N,), value=1.0, dtype=DTYPE) + tl.store(out_ptr + offs, x) + + +def test_constexpr_dtype_meta_param(): + """dtype passed as constexpr meta-param must work after universal wrapping.""" + out = torch.empty(8, dtype=torch.int32) + dtype_meta_kernel[(1,)](out, DTYPE=tl.int32, N=8) + assert (out == 1).all() + + +@triton_viz.trace(client="tracer") +@triton.jit +def tuple_meta_kernel(out_ptr, SHAPE: tl.constexpr): + offs = tl.arange(0, SHAPE[0]) + tl.store(out_ptr + offs, offs) + + +def test_constexpr_tuple_meta_param(): + """tuple meta-param constexpr must support __getitem__.""" + out = torch.empty(4, dtype=torch.int32) + tuple_meta_kernel[(1,)](out, SHAPE=(4,)) + assert (out == torch.arange(4, dtype=torch.int32)).all() + + +@triton_viz.trace(client="tracer") +@triton.jit +def grid_lambda_kernel(out_ptr, N: tl.constexpr, BLOCK: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK + tl.arange(0, BLOCK) + mask = offs < N + tl.store(out_ptr + offs, offs, mask=mask) + + +def test_constexpr_grid_lambda(): + """Grid lambda receives raw host values, not tl.constexpr wrappers.""" + N = 256 + out = torch.empty(N, dtype=torch.int32) + + def grid(META): + # Upstream Triton feeds raw Python values to grid lambdas; + # verify isinstance checks work as they would in real Triton. + assert isinstance( + META["BLOCK"], int + ), f"grid lambda should see raw int, got {type(META['BLOCK'])}" + return (triton.cdiv(N, META["BLOCK"]),) + + grid_lambda_kernel[grid](out, N=N, BLOCK=128) + assert (out == torch.arange(N, dtype=torch.int32)).all() + + +@triton_viz.trace(client="tracer") +@triton.jit +def bool_constexpr_kernel(out_ptr, flag: tl.constexpr, N: tl.constexpr): + offs = tl.arange(0, N) + val = flag.to(tl.int1) + tl.store(out_ptr + offs, val) + + +def test_constexpr_bool_cast(): + """bool constexpr .to(tl.int1) must work.""" + out = torch.empty(1, dtype=torch.int32) + bool_constexpr_kernel[(1,)](out, flag=True, N=1) + assert out.item() == 1 + + def test_reinterpret_tensor_wrapper(): """triton.reinterpret() produces a TensorWrapper; sanitizer must handle it.""" N = 64 x = torch.ones(N, dtype=torch.float16, device="cpu") y = torch.empty(N, dtype=torch.float16, device="cpu") copy_kernel[(1,)](triton.reinterpret(x, tl.float16), y, N, BLOCK=64) + + +# ======== constexpr.to() Regression Tests =========== + + +@triton_viz.trace(client="tracer") +@triton.jit +def bitcast_float_to_int_kernel(out_ptr, v: tl.constexpr, N: tl.constexpr): + offs = tl.arange(0, N) + result = v.to(tl.int32, bitcast=True) + tl.store(out_ptr + offs, result) + + +def test_constexpr_bitcast_float_src(): + """bitcast of float constexpr must use float32 source, not float64.""" + import struct + + out = torch.empty(1, dtype=torch.int32) + bitcast_float_to_int_kernel[(1,)](out, v=1.0, N=1) + expected = struct.unpack("i", struct.pack("f", 1.0))[0] # 0x3f800000 + assert out.item() == expected + + +@triton_viz.trace(client="tracer") +@triton.jit +def fp_downcast_kernel( + out_ptr, v: tl.constexpr, rounding: tl.constexpr, N: tl.constexpr +): + offs = tl.arange(0, N) + result = v.to(tl.float16, fp_downcast_rounding=rounding) + tl.store(out_ptr + offs, result) + + +def test_constexpr_fp_downcast_rtz(): + """fp_downcast_rounding='rtz' must truncate toward zero, not round to nearest.""" + # 1.00146484375 (fp32) is the exact midpoint between fp16 values + # 1.0009765625 and 1.001953125. + # rtne rounds UP to 1.001953125; rtz truncates DOWN to 1.0009765625. + v = 1.00146484375 + out_rtz = torch.empty(1, dtype=torch.float16) + fp_downcast_kernel[(1,)](out_rtz, v=v, rounding="rtz", N=1) + out_rtne = torch.empty(1, dtype=torch.float16) + fp_downcast_kernel[(1,)](out_rtne, v=v, rounding="rtne", N=1) + assert float(out_rtz.item()) == pytest.approx(1.0009765625) + assert float(out_rtne.item()) == pytest.approx(1.001953125) + + +@triton_viz.trace(client="tracer") +@triton.jit +def bool_mask_kernel(out_ptr, flag: tl.constexpr, N: tl.constexpr): + offs = tl.arange(0, N) + cond = flag.to(tl.int1) + result = tl.where( + cond, tl.full((N,), 42, dtype=tl.int32), tl.full((N,), -1, dtype=tl.int32) + ) + tl.store(out_ptr + offs, result) + + +def test_constexpr_bool_as_mask(): + """bool constexpr .to(tl.int1) must work as a predicate in tl.where.""" + out = torch.empty(4, dtype=torch.int32) + bool_mask_kernel[(1,)](out, flag=True, N=4) + assert (out == 42).all() + bool_mask_kernel[(1,)](out, flag=False, N=4) + assert (out == -1).all() + + +def test_constexpr_bitcast_mismatched_size_raises(): + """bitcast between types of different sizes must raise ValueError.""" + from triton_viz.core.patch import _constexpr_to + import types + + mock_self = types.SimpleNamespace(value=1.0) + # float(1.0) -> float32 (32 bits), tl.int64 (64 bits) -> mismatch + with pytest.raises(ValueError, match="Cannot bitcast"): + _constexpr_to(mock_self, tl.int64, bitcast=True) + + +def test_constexpr_to_unsupported_dtype_raises(): + """Casting constexpr to an exotic dtype without numpy equivalent must raise, not KeyError.""" + from triton_viz.core.patch import _constexpr_to + import types + + mock_self = types.SimpleNamespace(value=1.0) + for fp8_name in ["float8e4nv", "float8e5"]: + dt = getattr(tl, fp8_name, None) + if dt is None: + continue + with pytest.raises(TypeError, match="not supported"): + _constexpr_to(mock_self, dt) + + +# ======== Unsigned Division Regression Test =========== + + +@triton_viz.trace(client="tracer") +@triton.jit +def unsigned_div_kernel(out_ptr, DENOM: tl.constexpr): + numerator = tl.full((1,), 8, tl.uint32) + denom = DENOM.to(tl.uint32) + out = numerator // denom + tl.store(out_ptr + tl.arange(0, 1), out) + + +def test_constexpr_unsigned_div(): + out = torch.empty(1, dtype=torch.int32) + unsigned_div_kernel[(1,)](out, DENOM=1) + assert out.item() == 8 diff --git a/tests/unit/test_constexpr_to.py b/tests/unit/test_constexpr_to.py new file mode 100644 index 00000000..8f013f4f --- /dev/null +++ b/tests/unit/test_constexpr_to.py @@ -0,0 +1,64 @@ +import types + +import pytest +import numpy as np +import triton.language as tl + +from triton_viz.core.patch import _constexpr_to, _src_np_dtype + + +@pytest.mark.parametrize( + "value,triton_dtype,expected_np,expected_val", + [ + (1.0, tl.float64, np.float64, 1.0), + (1, tl.uint8, np.uint8, 1), + (-7, tl.int16, np.int16, -7), + (1, tl.uint32, np.uint32, 1), + # Narrowing / signedness-change casts — must wrap, not reject + (300, tl.uint8, np.uint8, 300 % 256), # 44 + (-1, tl.uint32, np.uint32, 2**32 - 1), # 4294967295 + (2**31, tl.int32, np.int32, -(2**31)), # signed wrap + (255, tl.int8, np.int8, -1), # unsigned → signed wrap + ], +) +def test_constexpr_to_preserves_dtype(value, triton_dtype, expected_np, expected_val): + mock_self = types.SimpleNamespace(value=value) + ret = _constexpr_to(mock_self, triton_dtype) + assert isinstance(ret, tl.core.tensor) + assert ret.dtype == triton_dtype + assert ret.handle.data.dtype == np.dtype(expected_np) + assert ret.handle.data.item() == expected_val + + +def test_constexpr_bitcast_preserves_dtype(): + mock_self = types.SimpleNamespace(value=42) + ret = _constexpr_to(mock_self, tl.float32, bitcast=True) + assert isinstance(ret, tl.core.tensor) + assert ret.dtype == tl.float32 + + +@pytest.mark.parametrize("bad_mode", ["RTZ", "foo", ""]) +def test_constexpr_to_invalid_rounding_mode_raises(bad_mode): + """Invalid fp_downcast_rounding values must raise ValueError, not silently fall back.""" + mock_self = types.SimpleNamespace(value=1.0) + with pytest.raises(ValueError, match="fp_downcast_rounding must be one of"): + _constexpr_to(mock_self, tl.float16, fp_downcast_rounding=bad_mode) + + +def test_constexpr_to_overflow_raises(): + """Integer literals outside [-2**63, 2**64) must raise OverflowError.""" + mock_self = types.SimpleNamespace(value=-(2**63) - 1) + with pytest.raises(OverflowError, match="outside the representable range"): + _constexpr_to(mock_self, tl.int64) + + +def test_constexpr_bitcast_bool_to_int8_raises(): + """Bitcast bool(int1) -> int8 must raise: primitive_bitwidth 1 != 8.""" + mock_self = types.SimpleNamespace(value=True) + with pytest.raises(ValueError, match="Cannot bitcast"): + _constexpr_to(mock_self, tl.int8, bitcast=True) + + +def test_src_np_dtype_max_uint64(): + """Values in [2**63, 2**64) must map to uint64, not overflow.""" + assert _src_np_dtype(2**64 - 1) == np.dtype(np.uint64) diff --git a/tests/unit/test_patch_scope.py b/tests/unit/test_patch_scope.py index 59fd418c..1403cb3e 100644 --- a/tests/unit/test_patch_scope.py +++ b/tests/unit/test_patch_scope.py @@ -3,7 +3,7 @@ import triton.language as tl from triton_viz.core import patch as patch_mod -from triton_viz.core.patch import _triton_snapshot_scope +from triton_viz.core.patch import _triton_snapshot_scope, patch_lang, unpatch_lang def _dummy_kernel(): @@ -72,3 +72,29 @@ def _getmembers(obj): scope.restore() assert getattr(descriptor, attr) is original + + +def test_constexpr_patch_lifecycle(): + """patch_lang must add .to/__getattr__ to constexpr; unpatch_lang must restore.""" + # Snapshot before-state + had_to = hasattr(tl.constexpr, "to") + had_getattr = hasattr(tl.constexpr, "__getattr__") + orig_to = getattr(tl.constexpr, "to", None) + orig_getattr = getattr(tl.constexpr, "__getattr__", None) + + patch_lang(_dummy_kernel, "triton") + try: + assert hasattr(tl.constexpr, "to") + assert callable(tl.constexpr.to) + assert hasattr(tl.constexpr, "__getattr__") + assert callable(tl.constexpr.__getattr__) + finally: + unpatch_lang("triton") + + # Verify exact before-state is restored + assert hasattr(tl.constexpr, "to") == had_to + assert hasattr(tl.constexpr, "__getattr__") == had_getattr + if had_to: + assert getattr(tl.constexpr, "to") is orig_to + if had_getattr: + assert getattr(tl.constexpr, "__getattr__") is orig_getattr diff --git a/tests/unit/test_symbolic_bool.py b/tests/unit/test_symbolic_bool.py new file mode 100644 index 00000000..551cd73a --- /dev/null +++ b/tests/unit/test_symbolic_bool.py @@ -0,0 +1,13 @@ +import triton.language as tl + +from triton_viz.clients.symbolic_engine import SymbolicExpr + + +def test_bool_inferred_as_int1(): + assert SymbolicExpr.from_value(True).dtype is tl.int1 + assert SymbolicExpr.from_value(False).dtype is tl.int1 + + +def test_constexpr_bool_inferred_as_int1(): + assert SymbolicExpr.from_value(tl.constexpr(True)).dtype is tl.int1 + assert SymbolicExpr.from_value(tl.constexpr(False)).dtype is tl.int1 diff --git a/triton_viz/clients/symbolic_engine.py b/triton_viz/clients/symbolic_engine.py index c392f4f8..6ad631d7 100644 --- a/triton_viz/clients/symbolic_engine.py +++ b/triton_viz/clients/symbolic_engine.py @@ -429,7 +429,7 @@ def _node_label_core(self) -> str: tl.float32, tl.float64, ) - builtin_scala_types: ClassVar[tuple[type, ...]] = (int, float) + builtin_scala_types: ClassVar[tuple[type, ...]] = (bool, int, float) tuple_types: ClassVar[tuple[type, ...]] = (tl.core.tuple, tuple, list) @staticmethod @@ -458,8 +458,12 @@ def _infer_literal_dtype(var: Any) -> tl.core.dtype | tl.pointer_type: f"All elements in the tuple must have the same dtype, but found {first_dtype} and {SymbolicExpr.from_value(v).dtype}" ) return first_dtype - if isinstance(var, SymbolicExpr.builtin_scala_types): - return tl.int32 if isinstance(var, int) else tl.float32 + if isinstance(var, bool): + return tl.int1 + if isinstance(var, int): + return tl.int32 + if isinstance(var, float): + return tl.float32 raise ValueError(f"Unsupported type: {type(var)}") # Stored on the class and may be accessed through either the class or an instance; @@ -475,6 +479,8 @@ def set_loop_ctx_provider(cls, fn: Callable[..., LoopContext | None]) -> None: @classmethod def from_value(cls, var: Any) -> SymbolicExpr: """Create a SymbolicExpr from a Python value.""" + if isinstance(var, tl.constexpr): # unwrap constexpr to raw value + var = var.value if isinstance(var, tl.core.tensor): # if a triton tensor var = var.handle # get its handle diff --git a/triton_viz/core/patch.py b/triton_viz/core/patch.py index 17239045..69d4a118 100644 --- a/triton_viz/core/patch.py +++ b/triton_viz/core/patch.py @@ -18,9 +18,12 @@ RawStore, ) import inspect +import numpy as np from triton.runtime.interpreter import ( GridExecutor, + TensorHandle, _implicit_cvt, + _get_np_dtype, interpreter_builder, ) from triton.runtime.interpreter import _patch_lang as triton_patch_lang @@ -36,6 +39,170 @@ _MISSING = object() +_F32_MIN_NORMAL = 2**-126 +_F32_MAX = (2 - 2**-23) * 2**127 + +_RTZ_STORAGE: dict[np.dtype, type[np.unsignedinteger]] = { + np.dtype(np.float16): np.uint16, + np.dtype(np.float32): np.uint32, + np.dtype(np.float64): np.uint64, +} + +try: + import ml_dtypes as _ml_dtypes + + _RTZ_STORAGE[np.dtype(_ml_dtypes.bfloat16)] = np.uint16 +except ImportError: + _ml_dtypes = None # type: ignore[assignment] + + +def _src_triton_dtype(value: object) -> tl.core.dtype: + """Infer the Triton dtype for a Python literal, mirroring upstream to_tensor.""" + if isinstance(value, bool): + return tl.int1 + if isinstance(value, int): + if -(2**31) <= value < 2**31: + return tl.int32 + if 2**31 <= value < 2**32: + return tl.uint32 + if -(2**63) <= value < 2**63: + return tl.int64 + if 0 <= value < 2**64: + return tl.uint64 + raise OverflowError( + f"Integer literal {value} is outside the representable range " + f"[-2**63, 2**64) for Triton integer types" + ) + assert isinstance(value, float) + abs_x: float = abs(value) + if ( + abs_x == 0.0 + or abs_x != abs_x + or abs_x == float("inf") + or (_F32_MIN_NORMAL <= abs_x <= _F32_MAX) + ): + return tl.float32 + return tl.float64 + + +def _src_np_dtype(value: object) -> np.dtype: + """Infer the numpy dtype for a Python literal, matching Triton's to_tensor logic.""" + return np.dtype(_get_np_dtype(_src_triton_dtype(value))) + + +def _cast_np_dtype(triton_dtype: tl.core.dtype) -> np.dtype: + """Numpy dtype for semantic cast. Raises for exotic types without numpy equivalents.""" + np_dt = _get_np_dtype(triton_dtype) + # _get_np_dtype maps bf16->uint16 and fp8->uint8 (storage types). + # For cast we need semantic types; bf16 is available via ml_dtypes. + if triton_dtype == tl.bfloat16: + if _ml_dtypes is None: + raise TypeError("constexpr.to(bfloat16) requires ml_dtypes to be installed") + return np.dtype(_ml_dtypes.bfloat16) + if np_dt in (np.uint8, np.uint16) and triton_dtype not in ( + tl.uint8, + tl.uint16, + tl.int1, + ): + raise TypeError( + f"constexpr.to({triton_dtype}) is not supported in interpreter mode " + f"(no numpy-compatible semantic type for {triton_dtype})" + ) + return np.dtype(np_dt) + + +def _fp_downcast_rtz(value: float, dst_np_dtype: np.dtype) -> np.generic: + """Round-toward-zero float downcast by correcting rtne when it rounds away from zero.""" + rtne_result = dst_np_dtype.type(value) + if abs(float(rtne_result)) > abs(value): + # rtne rounded away from zero -- subtract 1 ULP + storage = _RTZ_STORAGE[dst_np_dtype] + bits = np.frombuffer( + np.array([rtne_result], dtype=dst_np_dtype).tobytes(), + dtype=storage, + )[0] + bits -= np.array(1, dtype=storage) + return np.frombuffer( + np.array([bits], dtype=storage).tobytes(), + dtype=dst_np_dtype, + )[0] + return rtne_result + + +def _typed_scalar_tensor(result: np.generic, dtype: tl.core.dtype) -> tl.core.tensor: + """Wrap a numpy scalar as a Triton interpreter tensor with an exact dtype.""" + storage_np = _get_np_dtype(dtype) + arr = np.array([result]) + if arr.dtype != storage_np: + # Semantic/storage dtype differ (e.g. ml_dtypes.bfloat16 -> uint16) + arr = arr.view(storage_np) + return tl.core.tensor(TensorHandle(arr, dtype), dtype) + + +def _constexpr_to(self, dtype, fp_downcast_rounding=None, bitcast=False): + """Interpreter-mode implementation of constexpr.to(dtype). + + Computes the cast with numpy directly (not via interpreter_semantic) + to avoid re-entering patched tl ops when a client like the sanitizer + is active. + """ + value = self.value if not isinstance(self.value, tl.constexpr) else self.value.value + dtype = dtype.value if isinstance(dtype, tl.constexpr) else dtype + fp_downcast_rounding = ( + fp_downcast_rounding.value + if isinstance(fp_downcast_rounding, tl.constexpr) + else fp_downcast_rounding + ) + bitcast_val = bitcast.value if isinstance(bitcast, tl.constexpr) else bitcast + + if bitcast_val: + src_triton = _src_triton_dtype(value) + src_bits = src_triton.primitive_bitwidth + dst_bits = dtype.primitive_bitwidth + if src_bits != dst_bits: + raise ValueError( + f"Cannot bitcast data-type of size {src_bits} to " + f"data-type of size {dst_bits}" + ) + src_np = _src_np_dtype(value) + dst_np = np.dtype(_get_np_dtype(dtype)) + raw = np.array([value], dtype=src_np).tobytes() + result = np.frombuffer(raw, dtype=dst_np)[0] + else: + dst_np = _cast_np_dtype(dtype) + src_np = _src_np_dtype(value) + + if fp_downcast_rounding is not None: + _VALID_ROUNDING_MODES = ("rtne", "rtz") + if fp_downcast_rounding not in _VALID_ROUNDING_MODES: + raise ValueError( + f"fp_downcast_rounding must be one of {_VALID_ROUNDING_MODES}, " + f"got {fp_downcast_rounding!r}" + ) + src_is_float = src_np.kind == "f" + dst_is_float = dst_np.kind == "f" or ( + _ml_dtypes is not None and dst_np == np.dtype(_ml_dtypes.bfloat16) + ) + if not ( + dst_is_float and src_is_float and dst_np.itemsize < src_np.itemsize + ): + raise ValueError( + f"fp_downcast_rounding is only valid for float-to-smaller-float casts, " + f"got {src_np} -> {dst_np}" + ) + + if fp_downcast_rounding == "rtz": + result = _fp_downcast_rtz(value, dst_np) + else: + result = np.array([value], dtype=src_np).astype(dst_np)[0] + + return _typed_scalar_tensor(result, dtype) + + +def _constexpr_getattr(self, name): + """Proxy attribute access to the wrapped value for interpreter mode.""" + return getattr(self.value, name) + class _LangPatchScope: """Tracks patched attributes so they can be restored.""" @@ -57,6 +224,27 @@ def restore(self) -> None: setattr(obj, name, original) +def _patch_constexpr(scope: _LangPatchScope) -> None: + """Patch tl.constexpr with .to() and __getattr__ for interpreter mode.""" + if not hasattr(tl.constexpr, "to"): + scope.set_attr(tl.constexpr, "to", _constexpr_to) + if not hasattr(tl.constexpr, "__getattr__"): + scope.set_attr(tl.constexpr, "__getattr__", _constexpr_getattr) + + +def _normalize_constexpr_arg(arg): + """Wrap a kernel argument in tl.constexpr if it isn't already one. + + None is left as-is because Triton APIs accept ``constexpr | None`` + (e.g. ``reduce(..., dtype=None)``). + """ + if isinstance(arg, tl.constexpr): + return arg + if arg is None: + return None + return tl.constexpr(arg) + + _LANG_PATCH_SCOPES: dict[str, list[Any]] = {"triton": [], "nki": [], "nki_beta2": []} @@ -325,6 +513,7 @@ def patch_lang(fn, backend, client_manager=None): if backend == "triton": scope = _triton_snapshot_scope(fn) triton_patch_lang(fn) + _patch_constexpr(scope) elif backend == "nki": from triton_viz.core.nki import nki_patch_lang @@ -445,18 +634,24 @@ def _grid_executor_call(self, *args_dev, backend=None, **kwargs): # Prepare call arguments args = inspect.getcallargs(self.fn, *args_hst, **kwargs_hst) + # grid_args keeps constexprs as raw host values (matching upstream Triton) + # so that grid lambdas see plain Python types, not tl.constexpr wrappers. + grid_args = { + name: (arg if name in self.constexprs else _implicit_cvt(arg)) + for name, arg in args.items() + } + grid_args.pop("self", None) call_args = {} for name, arg in args.items(): if name in self.constexprs: - call_args[name] = arg - ret = arg + ret = _normalize_constexpr_arg(arg) else: ret = _implicit_cvt(arg) client_manager.arg_callback(name, arg, ret) call_args[name] = ret call_args.pop("self", None) # Iterate through grid - grid = self.grid(call_args) if callable(self.grid) else self.grid + grid = self.grid(grid_args) if callable(self.grid) else self.grid assert len(grid) <= 3 grid = grid + (1,) * (3 - len(grid))