From 0524cafe1ca82a4d92b367c0055613561c3f0b5c Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 4 Mar 2026 02:38:46 +0000 Subject: [PATCH 1/7] [FIX] Support constexpr .to() in interpreter mode In compiled Triton, constexpr.to(dtype) works via the compiler pipeline, but in interpreter mode constexpr has no .to() method, and raw Python scalars passed as constexprs are not wrapped. Three coupled fixes: - Monkey-patch tl.constexpr.to() for interpreter mode - Wrap scalar constexprs in tl.constexpr in _grid_executor_call - Unwrap tl.constexpr in SymbolicExpr.from_value to prevent type errors --- tests/end_to_end/test_sanitizer.py | 26 ++++++++++++++++++++++---- triton_viz/clients/symbolic_engine.py | 2 ++ triton_viz/core/patch.py | 16 ++++++++++++++-- 3 files changed, 38 insertions(+), 6 deletions(-) diff --git a/tests/end_to_end/test_sanitizer.py b/tests/end_to_end/test_sanitizer.py index 24e2e0dd..421f4370 100644 --- a/tests/end_to_end/test_sanitizer.py +++ b/tests/end_to_end/test_sanitizer.py @@ -301,7 +301,7 @@ def test_loop_deferred_checks_simplify(): # Dedicated sanitizer for nested loop regression test -nested_loop_checker = SymbolicSanitizer(abort_on_error=False) +nested_loop_checker = SymbolicSanitizer() @triton_viz.trace(client=nested_loop_checker) @@ -329,7 +329,7 @@ def test_nested_loop_no_false_positive(): # Create a dedicated sanitizer for line number tests -line_number_checker: SymbolicSanitizer = SymbolicSanitizer(abort_on_error=False) +line_number_checker: SymbolicSanitizer = SymbolicSanitizer() @triton_viz.trace(client=line_number_checker) @@ -452,7 +452,7 @@ def test_gemm_oob_call_stack(): # ======== Block Tensor (Block Pointer) Tests =========== -block_sanitizer = SymbolicSanitizer(abort_on_error=False) +block_sanitizer = SymbolicSanitizer() @triton_viz.trace(client=block_sanitizer) @@ -813,7 +813,7 @@ def oob_kernel(ptr, BLOCK: tl.constexpr): # ======== Reduce + Broadcast Tests =========== -reduce_broadcast_sanitizer = SymbolicSanitizer(abort_on_error=False) +reduce_broadcast_sanitizer = SymbolicSanitizer() @triton_viz.trace(client=reduce_broadcast_sanitizer) @@ -873,3 +873,21 @@ def test_oob_with_fake_tensor(): fake_tensor_oob_kernel[(1,)](x, out, N=8) finally: config.virtual_memory = old_virtual_memory + + +@triton_viz.trace(client=SymbolicSanitizer()) +@triton.jit +def cast_scalar_kernel(x_ptr, out_ptr, eps: tl.constexpr, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(x_ptr + offs) + dtype = x.dtype + eps_cast = eps.to(dtype) + y = x + eps_cast + tl.store(out_ptr + offs, y) + + +def test_float_no_attr_to(): + """constexpr float .to(dtype) must work in interpreter mode.""" + x = torch.randn(8, device="cpu") + out = torch.empty(8, device="cpu") + cast_scalar_kernel[(1,)](x, out, eps=1e-6, N=8) diff --git a/triton_viz/clients/symbolic_engine.py b/triton_viz/clients/symbolic_engine.py index d279987e..ade5454c 100644 --- a/triton_viz/clients/symbolic_engine.py +++ b/triton_viz/clients/symbolic_engine.py @@ -467,6 +467,8 @@ def set_loop_ctx_provider(cls, fn: Callable[..., LoopContext | None]) -> None: @classmethod def from_value(cls, var: Any) -> SymbolicExpr: """Create a SymbolicExpr from a Python value.""" + if isinstance(var, tl.constexpr): # unwrap constexpr to raw value + var = var.value if isinstance(var, tl.core.tensor): # if a triton tensor var = var.handle # get its handle diff --git a/triton_viz/core/patch.py b/triton_viz/core/patch.py index 2d48800e..603364af 100644 --- a/triton_viz/core/patch.py +++ b/triton_viz/core/patch.py @@ -36,6 +36,16 @@ _MISSING = object() +# Monkey-patch tl.constexpr to add .to() for interpreter mode. +# In compiled Triton, constexpr.to(dtype) works via the compiler pipeline, +# but in interpreter mode constexpr has no .to() method. +if not hasattr(tl.constexpr, "to"): + + def _constexpr_to(self, dtype, fp_downcast_rounding=None, bitcast=False): + return _implicit_cvt(self.value) + + tl.constexpr.to = _constexpr_to + class _LangPatchScope: """Tracks patched attributes so they can be restored.""" @@ -422,8 +432,10 @@ def _grid_executor_call(self, *args_dev, backend=None, **kwargs): call_args = {} for name, arg in args.items(): if name in self.constexprs: - call_args[name] = arg - ret = arg + call_args[name] = ( + tl.constexpr(arg) if isinstance(arg, (int, float, bool)) else arg + ) + ret = call_args[name] else: ret = _implicit_cvt(arg) client_manager.arg_callback(name, arg, ret) From 985609e51036ab8b0cbfc5d5cc402b2f5c7fde3c Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 18 Mar 2026 01:16:52 +0000 Subject: [PATCH 2/7] [FIX] Upgrade constexpr.to() with real cast semantics, scoped patching, universal normalization, and bool support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace no-op _implicit_cvt shim with numpy-based cast that handles truncation (float→int) and bitcast (reinterpret bits) - Move constexpr monkey-patch from module-level to scoped _patch_constexpr() via _LangPatchScope, with hasattr guards for upstream compatibility - Add __getattr__ proxy on constexpr for attribute access (e.g. DTYPE.is_int()) - Universally wrap all constexpr args via _normalize_constexpr_arg(), not just int/float/bool, enabling dtype/tuple/custom-type meta-params - Add bool to symbolic engine's builtin_scala_types with int1 dtype inference - Add E2E tests for real cast, bitcast, dtype meta-param, tuple meta-param, grid lambda, and bool cast - Add unit tests for bool symbolic inference and constexpr patch lifecycle --- tests/end_to_end/test_sanitizer.py | 95 +++++++++++++++++++++++++ tests/unit/test_patch_scope.py | 28 +++++++- tests/unit/test_symbolic_bool.py | 13 ++++ triton_viz/clients/symbolic_engine.py | 10 ++- triton_viz/core/patch.py | 99 ++++++++++++++++++++++++--- 5 files changed, 230 insertions(+), 15 deletions(-) create mode 100644 tests/unit/test_symbolic_bool.py diff --git a/tests/end_to_end/test_sanitizer.py b/tests/end_to_end/test_sanitizer.py index 690f1eae..a6c7190d 100644 --- a/tests/end_to_end/test_sanitizer.py +++ b/tests/end_to_end/test_sanitizer.py @@ -1190,6 +1190,101 @@ def test_reduce_on_batched_dot_result(): ), f"Expected no OOB records, got {len(reduce_dot_sanitizer.records)}" +@triton_viz.trace(client="tracer") +@triton.jit +def cast_truncate_kernel(out_ptr, v: tl.constexpr, N: tl.constexpr): + offs = tl.arange(0, N) + result = v.to(tl.int32) + tl.store(out_ptr + offs, result) + + +def test_constexpr_to_real_cast(): + """constexpr.to(tl.int32) must truncate, not no-op.""" + out = torch.empty(1, dtype=torch.int32) + cast_truncate_kernel[(1,)](out, v=1.9, N=1) + assert out.item() == 1 + + +@triton_viz.trace(client="tracer") +@triton.jit +def cast_bitcast_kernel(out_ptr, v: tl.constexpr, N: tl.constexpr): + offs = tl.arange(0, N) + result = v.to(tl.float32, bitcast=True) + tl.store(out_ptr + offs, result) + + +def test_constexpr_to_bitcast(): + """constexpr.to(tl.float32, bitcast=True) must reinterpret bits.""" + import struct + + out = torch.empty(1, dtype=torch.float32) + cast_bitcast_kernel[(1,)](out, v=42, N=1) + expected = struct.unpack("f", struct.pack("i", 42))[0] + assert out.item() == pytest.approx(expected) + + +@triton_viz.trace(client="tracer") +@triton.jit +def dtype_meta_kernel(out_ptr, DTYPE: tl.constexpr, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.full((N,), value=1.0, dtype=DTYPE) + tl.store(out_ptr + offs, x) + + +def test_constexpr_dtype_meta_param(): + """dtype passed as constexpr meta-param must work after universal wrapping.""" + out = torch.empty(8, dtype=torch.int32) + dtype_meta_kernel[(1,)](out, DTYPE=tl.int32, N=8) + assert (out == 1).all() + + +@triton_viz.trace(client="tracer") +@triton.jit +def tuple_meta_kernel(out_ptr, SHAPE: tl.constexpr): + offs = tl.arange(0, SHAPE[0]) + tl.store(out_ptr + offs, offs) + + +def test_constexpr_tuple_meta_param(): + """tuple meta-param constexpr must support __getitem__.""" + out = torch.empty(4, dtype=torch.int32) + tuple_meta_kernel[(1,)](out, SHAPE=(4,)) + assert (out == torch.arange(4, dtype=torch.int32)).all() + + +@triton_viz.trace(client="tracer") +@triton.jit +def grid_lambda_kernel(out_ptr, N: tl.constexpr, BLOCK: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK + tl.arange(0, BLOCK) + mask = offs < N + tl.store(out_ptr + offs, offs, mask=mask) + + +def test_constexpr_grid_lambda(): + """Grid lambda with META dict must work with universally-wrapped constexprs.""" + N = 256 + out = torch.empty(N, dtype=torch.int32) + grid = lambda META: (triton.cdiv(N, META["BLOCK"]),) + grid_lambda_kernel[grid](out, N=N, BLOCK=128) + assert (out == torch.arange(N, dtype=torch.int32)).all() + + +@triton_viz.trace(client="tracer") +@triton.jit +def bool_constexpr_kernel(out_ptr, flag: tl.constexpr, N: tl.constexpr): + offs = tl.arange(0, N) + val = flag.to(tl.int1) + tl.store(out_ptr + offs, val) + + +def test_constexpr_bool_cast(): + """bool constexpr .to(tl.int1) must work.""" + out = torch.empty(1, dtype=torch.int32) + bool_constexpr_kernel[(1,)](out, flag=True, N=1) + assert out.item() == 1 + + def test_reinterpret_tensor_wrapper(): """triton.reinterpret() produces a TensorWrapper; sanitizer must handle it.""" N = 64 diff --git a/tests/unit/test_patch_scope.py b/tests/unit/test_patch_scope.py index 59fd418c..1403cb3e 100644 --- a/tests/unit/test_patch_scope.py +++ b/tests/unit/test_patch_scope.py @@ -3,7 +3,7 @@ import triton.language as tl from triton_viz.core import patch as patch_mod -from triton_viz.core.patch import _triton_snapshot_scope +from triton_viz.core.patch import _triton_snapshot_scope, patch_lang, unpatch_lang def _dummy_kernel(): @@ -72,3 +72,29 @@ def _getmembers(obj): scope.restore() assert getattr(descriptor, attr) is original + + +def test_constexpr_patch_lifecycle(): + """patch_lang must add .to/__getattr__ to constexpr; unpatch_lang must restore.""" + # Snapshot before-state + had_to = hasattr(tl.constexpr, "to") + had_getattr = hasattr(tl.constexpr, "__getattr__") + orig_to = getattr(tl.constexpr, "to", None) + orig_getattr = getattr(tl.constexpr, "__getattr__", None) + + patch_lang(_dummy_kernel, "triton") + try: + assert hasattr(tl.constexpr, "to") + assert callable(tl.constexpr.to) + assert hasattr(tl.constexpr, "__getattr__") + assert callable(tl.constexpr.__getattr__) + finally: + unpatch_lang("triton") + + # Verify exact before-state is restored + assert hasattr(tl.constexpr, "to") == had_to + assert hasattr(tl.constexpr, "__getattr__") == had_getattr + if had_to: + assert getattr(tl.constexpr, "to") is orig_to + if had_getattr: + assert getattr(tl.constexpr, "__getattr__") is orig_getattr diff --git a/tests/unit/test_symbolic_bool.py b/tests/unit/test_symbolic_bool.py new file mode 100644 index 00000000..551cd73a --- /dev/null +++ b/tests/unit/test_symbolic_bool.py @@ -0,0 +1,13 @@ +import triton.language as tl + +from triton_viz.clients.symbolic_engine import SymbolicExpr + + +def test_bool_inferred_as_int1(): + assert SymbolicExpr.from_value(True).dtype is tl.int1 + assert SymbolicExpr.from_value(False).dtype is tl.int1 + + +def test_constexpr_bool_inferred_as_int1(): + assert SymbolicExpr.from_value(tl.constexpr(True)).dtype is tl.int1 + assert SymbolicExpr.from_value(tl.constexpr(False)).dtype is tl.int1 diff --git a/triton_viz/clients/symbolic_engine.py b/triton_viz/clients/symbolic_engine.py index 668ed743..882ef974 100644 --- a/triton_viz/clients/symbolic_engine.py +++ b/triton_viz/clients/symbolic_engine.py @@ -429,7 +429,7 @@ def _node_label_core(self) -> str: tl.float32, tl.float64, ) - builtin_scala_types: ClassVar[tuple[type, ...]] = (int, float) + builtin_scala_types: ClassVar[tuple[type, ...]] = (bool, int, float) tuple_types: ClassVar[tuple[type, ...]] = (tl.core.tuple, tuple, list) @staticmethod @@ -458,8 +458,12 @@ def _infer_literal_dtype(var: Any) -> tl.core.dtype | tl.pointer_type: f"All elements in the tuple must have the same dtype, but found {first_dtype} and {SymbolicExpr.from_value(v).dtype}" ) return first_dtype - if isinstance(var, SymbolicExpr.builtin_scala_types): - return tl.int32 if isinstance(var, int) else tl.float32 + if isinstance(var, bool): + return tl.int1 + if isinstance(var, int): + return tl.int32 + if isinstance(var, float): + return tl.float32 raise ValueError(f"Unsupported type: {type(var)}") # Stored on the class and may be accessed through either the class or an instance; diff --git a/triton_viz/core/patch.py b/triton_viz/core/patch.py index f6ee8a8e..499c1e12 100644 --- a/triton_viz/core/patch.py +++ b/triton_viz/core/patch.py @@ -18,6 +18,7 @@ RawStore, ) import inspect +import numpy as np from triton.runtime.interpreter import ( GridExecutor, _implicit_cvt, @@ -36,15 +37,72 @@ _MISSING = object() -# Monkey-patch tl.constexpr to add .to() for interpreter mode. -# In compiled Triton, constexpr.to(dtype) works via the compiler pipeline, -# but in interpreter mode constexpr has no .to() method. -if not hasattr(tl.constexpr, "to"): +_TRITON_DTYPE_TO_NP: dict[tl.core.dtype, type[np.generic]] = { + tl.int1: np.bool_, + tl.int8: np.int8, + tl.int16: np.int16, + tl.int32: np.int32, + tl.int64: np.int64, + tl.uint8: np.uint8, + tl.uint16: np.uint16, + tl.uint32: np.uint32, + tl.uint64: np.uint64, + tl.float16: np.float16, + tl.float32: np.float32, + tl.float64: np.float64, +} + + +def _src_np_dtype(value: object) -> np.dtype: + """Infer the numpy dtype for a Python literal, matching Triton's to_tensor logic.""" + if isinstance(value, bool): + return np.dtype(np.bool_) + if isinstance(value, int): + if -(2**31) <= value < 2**31: + return np.dtype(np.int32) + if 2**31 <= value < 2**32: + return np.dtype(np.uint32) + if -(2**63) <= value < 2**63: + return np.dtype(np.int64) + return np.dtype(np.uint64) + return np.dtype(np.float64) + + +def _constexpr_to(self, dtype, fp_downcast_rounding=None, bitcast=False): + """Interpreter-mode implementation of constexpr.to(dtype). + + Computes the cast with numpy directly (not via interpreter_semantic) + to avoid re-entering patched tl ops when a client like the sanitizer + is active. + """ + value = self.value if not isinstance(self.value, tl.constexpr) else self.value.value + dtype = dtype.value if isinstance(dtype, tl.constexpr) else dtype + fp_downcast_rounding = ( + fp_downcast_rounding.value + if isinstance(fp_downcast_rounding, tl.constexpr) + else fp_downcast_rounding + ) + bitcast_val = bitcast.value if isinstance(bitcast, tl.constexpr) else bitcast - def _constexpr_to(self, dtype, fp_downcast_rounding=None, bitcast=False): - return _implicit_cvt(self.value) + dst_np = np.dtype(_TRITON_DTYPE_TO_NP[dtype]) + + if bitcast_val: + src_np = _src_np_dtype(value) + raw = np.array([value], dtype=src_np).tobytes()[: dst_np.itemsize] + result = np.frombuffer(raw, dtype=dst_np)[0] + else: + result = dst_np.type(value) - tl.constexpr.to = _constexpr_to + py_val = result.item() + # bool must become int so _implicit_cvt doesn't hit TensorHandle int1 bug + if isinstance(py_val, bool): + py_val = int(py_val) + return _implicit_cvt(py_val) + + +def _constexpr_getattr(self, name): + """Proxy attribute access to the wrapped value for interpreter mode.""" + return getattr(self.value, name) class _LangPatchScope: @@ -67,6 +125,27 @@ def restore(self) -> None: setattr(obj, name, original) +def _patch_constexpr(scope: _LangPatchScope) -> None: + """Patch tl.constexpr with .to() and __getattr__ for interpreter mode.""" + if not hasattr(tl.constexpr, "to"): + scope.set_attr(tl.constexpr, "to", _constexpr_to) + if not hasattr(tl.constexpr, "__getattr__"): + scope.set_attr(tl.constexpr, "__getattr__", _constexpr_getattr) + + +def _normalize_constexpr_arg(arg): + """Wrap a kernel argument in tl.constexpr if it isn't already one. + + None is left as-is because Triton APIs accept ``constexpr | None`` + (e.g. ``reduce(..., dtype=None)``). + """ + if isinstance(arg, tl.constexpr): + return arg + if arg is None: + return None + return tl.constexpr(arg) + + _LANG_PATCH_SCOPES: dict[str, list[Any]] = {"triton": [], "nki": [], "nki_beta2": []} @@ -335,6 +414,7 @@ def patch_lang(fn, backend, client_manager=None): if backend == "triton": scope = _triton_snapshot_scope(fn) triton_patch_lang(fn) + _patch_constexpr(scope) elif backend == "nki": from triton_viz.core.nki import nki_patch_lang @@ -458,10 +538,7 @@ def _grid_executor_call(self, *args_dev, backend=None, **kwargs): call_args = {} for name, arg in args.items(): if name in self.constexprs: - call_args[name] = ( - tl.constexpr(arg) if isinstance(arg, (int, float, bool)) else arg - ) - ret = call_args[name] + ret = _normalize_constexpr_arg(arg) else: ret = _implicit_cvt(arg) client_manager.arg_callback(name, arg, ret) From 849f232e418aaca95f2842ab731eba847c580056 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 18 Mar 2026 23:57:43 +0000 Subject: [PATCH 3/7] [FIX] Validate bitwidth equality for bitcast in constexpr.to() Triton's semantic.cast() raises ValueError when source and destination primitive bitwidths differ for bitcast. Our interpreter-mode _constexpr_to was silently truncating via a bytes slice instead. Add an explicit size check and remove the unnecessary slice. --- tests/end_to_end/test_sanitizer.py | 90 ++++++++++++++++++++++ triton_viz/core/patch.py | 117 ++++++++++++++++++++++++----- 2 files changed, 189 insertions(+), 18 deletions(-) diff --git a/tests/end_to_end/test_sanitizer.py b/tests/end_to_end/test_sanitizer.py index 1ec917c3..e73ab1b7 100644 --- a/tests/end_to_end/test_sanitizer.py +++ b/tests/end_to_end/test_sanitizer.py @@ -1378,3 +1378,93 @@ def test_reinterpret_tensor_wrapper(): x = torch.ones(N, dtype=torch.float16, device="cpu") y = torch.empty(N, dtype=torch.float16, device="cpu") copy_kernel[(1,)](triton.reinterpret(x, tl.float16), y, N, BLOCK=64) + + +# ======== constexpr.to() Regression Tests =========== + + +@triton_viz.trace(client="tracer") +@triton.jit +def bitcast_float_to_int_kernel(out_ptr, v: tl.constexpr, N: tl.constexpr): + offs = tl.arange(0, N) + result = v.to(tl.int32, bitcast=True) + tl.store(out_ptr + offs, result) + + +def test_constexpr_bitcast_float_src(): + """bitcast of float constexpr must use float32 source, not float64.""" + import struct + + out = torch.empty(1, dtype=torch.int32) + bitcast_float_to_int_kernel[(1,)](out, v=1.0, N=1) + expected = struct.unpack("i", struct.pack("f", 1.0))[0] # 0x3f800000 + assert out.item() == expected + + +@triton_viz.trace(client="tracer") +@triton.jit +def fp_downcast_kernel( + out_ptr, v: tl.constexpr, rounding: tl.constexpr, N: tl.constexpr +): + offs = tl.arange(0, N) + result = v.to(tl.float16, fp_downcast_rounding=rounding) + tl.store(out_ptr + offs, result) + + +def test_constexpr_fp_downcast_rtz(): + """fp_downcast_rounding='rtz' must truncate toward zero, not round to nearest.""" + # 1.00146484375 (fp32) is the exact midpoint between fp16 values + # 1.0009765625 and 1.001953125. + # rtne rounds UP to 1.001953125; rtz truncates DOWN to 1.0009765625. + v = 1.00146484375 + out_rtz = torch.empty(1, dtype=torch.float16) + fp_downcast_kernel[(1,)](out_rtz, v=v, rounding="rtz", N=1) + out_rtne = torch.empty(1, dtype=torch.float16) + fp_downcast_kernel[(1,)](out_rtne, v=v, rounding="rtne", N=1) + assert float(out_rtz.item()) == pytest.approx(1.0009765625) + assert float(out_rtne.item()) == pytest.approx(1.001953125) + + +@triton_viz.trace(client="tracer") +@triton.jit +def bool_mask_kernel(out_ptr, flag: tl.constexpr, N: tl.constexpr): + offs = tl.arange(0, N) + cond = flag.to(tl.int1) + result = tl.where( + cond, tl.full((N,), 42, dtype=tl.int32), tl.full((N,), -1, dtype=tl.int32) + ) + tl.store(out_ptr + offs, result) + + +def test_constexpr_bool_as_mask(): + """bool constexpr .to(tl.int1) must work as a predicate in tl.where.""" + out = torch.empty(4, dtype=torch.int32) + bool_mask_kernel[(1,)](out, flag=True, N=4) + assert (out == 42).all() + bool_mask_kernel[(1,)](out, flag=False, N=4) + assert (out == -1).all() + + +def test_constexpr_bitcast_mismatched_size_raises(): + """bitcast between types of different sizes must raise ValueError.""" + from triton_viz.core.patch import _constexpr_to + import types + + mock_self = types.SimpleNamespace(value=1.0) + # float(1.0) -> float32 (4 bytes), tl.int64 (8 bytes) -> mismatch + with pytest.raises(ValueError, match="different sizes"): + _constexpr_to(mock_self, tl.int64, bitcast=True) + + +def test_constexpr_to_unsupported_dtype_raises(): + """Casting constexpr to an exotic dtype without numpy equivalent must raise, not KeyError.""" + from triton_viz.core.patch import _constexpr_to + import types + + mock_self = types.SimpleNamespace(value=1.0) + for fp8_name in ["float8e4nv", "float8e5"]: + dt = getattr(tl, fp8_name, None) + if dt is None: + continue + with pytest.raises(TypeError, match="not supported"): + _constexpr_to(mock_self, dt) diff --git a/triton_viz/core/patch.py b/triton_viz/core/patch.py index 499c1e12..55a01947 100644 --- a/triton_viz/core/patch.py +++ b/triton_viz/core/patch.py @@ -22,6 +22,7 @@ from triton.runtime.interpreter import ( GridExecutor, _implicit_cvt, + _get_np_dtype, interpreter_builder, ) from triton.runtime.interpreter import _patch_lang as triton_patch_lang @@ -37,21 +38,22 @@ _MISSING = object() -_TRITON_DTYPE_TO_NP: dict[tl.core.dtype, type[np.generic]] = { - tl.int1: np.bool_, - tl.int8: np.int8, - tl.int16: np.int16, - tl.int32: np.int32, - tl.int64: np.int64, - tl.uint8: np.uint8, - tl.uint16: np.uint16, - tl.uint32: np.uint32, - tl.uint64: np.uint64, - tl.float16: np.float16, - tl.float32: np.float32, - tl.float64: np.float64, +_F32_MIN_NORMAL = 2**-126 +_F32_MAX = (2 - 2**-23) * 2**127 + +_RTZ_STORAGE: dict[np.dtype, type[np.unsignedinteger]] = { + np.dtype(np.float16): np.uint16, + np.dtype(np.float32): np.uint32, + np.dtype(np.float64): np.uint64, } +try: + import ml_dtypes as _ml_dtypes + + _RTZ_STORAGE[np.dtype(_ml_dtypes.bfloat16)] = np.uint16 +except ImportError: + _ml_dtypes = None # type: ignore[assignment] + def _src_np_dtype(value: object) -> np.dtype: """Infer the numpy dtype for a Python literal, matching Triton's to_tensor logic.""" @@ -65,9 +67,58 @@ def _src_np_dtype(value: object) -> np.dtype: if -(2**63) <= value < 2**63: return np.dtype(np.int64) return np.dtype(np.uint64) + # Float: match Triton to_tensor() — float32 for representable values + assert isinstance(value, float) + abs_x: float = abs(value) + if ( + abs_x == 0.0 + or abs_x != abs_x + or abs_x == float("inf") + or (_F32_MIN_NORMAL <= abs_x <= _F32_MAX) + ): + return np.dtype(np.float32) return np.dtype(np.float64) +def _cast_np_dtype(triton_dtype: tl.core.dtype) -> np.dtype: + """Numpy dtype for semantic cast. Raises for exotic types without numpy equivalents.""" + np_dt = _get_np_dtype(triton_dtype) + # _get_np_dtype maps bf16->uint16 and fp8->uint8 (storage types). + # For cast we need semantic types; bf16 is available via ml_dtypes. + if triton_dtype == tl.bfloat16: + if _ml_dtypes is None: + raise TypeError("constexpr.to(bfloat16) requires ml_dtypes to be installed") + return np.dtype(_ml_dtypes.bfloat16) + if np_dt in (np.uint8, np.uint16) and triton_dtype not in ( + tl.uint8, + tl.uint16, + tl.int1, + ): + raise TypeError( + f"constexpr.to({triton_dtype}) is not supported in interpreter mode " + f"(no numpy-compatible semantic type for {triton_dtype})" + ) + return np.dtype(np_dt) + + +def _fp_downcast_rtz(value: float, dst_np_dtype: np.dtype) -> np.generic: + """Round-toward-zero float downcast by correcting rtne when it rounds away from zero.""" + rtne_result = dst_np_dtype.type(value) + if abs(float(rtne_result)) > abs(value): + # rtne rounded away from zero -- subtract 1 ULP + storage = _RTZ_STORAGE[dst_np_dtype] + bits = np.frombuffer( + np.array([rtne_result], dtype=dst_np_dtype).tobytes(), + dtype=storage, + )[0] + bits -= np.array(1, dtype=storage) + return np.frombuffer( + np.array([bits], dtype=storage).tobytes(), + dtype=dst_np_dtype, + )[0] + return rtne_result + + def _constexpr_to(self, dtype, fp_downcast_rounding=None, bitcast=False): """Interpreter-mode implementation of constexpr.to(dtype). @@ -84,17 +135,47 @@ def _constexpr_to(self, dtype, fp_downcast_rounding=None, bitcast=False): ) bitcast_val = bitcast.value if isinstance(bitcast, tl.constexpr) else bitcast - dst_np = np.dtype(_TRITON_DTYPE_TO_NP[dtype]) - if bitcast_val: src_np = _src_np_dtype(value) - raw = np.array([value], dtype=src_np).tobytes()[: dst_np.itemsize] + dst_np = np.dtype(_get_np_dtype(dtype)) + if src_np.itemsize != dst_np.itemsize: + raise ValueError( + f"Cannot bitcast between types of different sizes: " + f"{src_np} ({src_np.itemsize * 8} bits) and " + f"{dst_np} ({dst_np.itemsize * 8} bits)" + ) + raw = np.array([value], dtype=src_np).tobytes() result = np.frombuffer(raw, dtype=dst_np)[0] else: - result = dst_np.type(value) + dst_np = _cast_np_dtype(dtype) + src_np = _src_np_dtype(value) + + if fp_downcast_rounding is not None: + src_is_float = src_np.kind == "f" + dst_is_float = dst_np.kind == "f" or ( + _ml_dtypes is not None and dst_np == np.dtype(_ml_dtypes.bfloat16) + ) + if not ( + dst_is_float and src_is_float and dst_np.itemsize < src_np.itemsize + ): + raise ValueError( + f"fp_downcast_rounding is only valid for float-to-smaller-float casts, " + f"got {src_np} -> {dst_np}" + ) + + if fp_downcast_rounding == "rtz": + result = _fp_downcast_rtz(value, dst_np) + else: + result = dst_np.type(value) py_val = result.item() - # bool must become int so _implicit_cvt doesn't hit TensorHandle int1 bug + # bool/int1: Triton's TensorHandle.__post_init__ rejects np.array([True], dtype=np.int32) + # paired with tl.int1 because itemsize(int32)=32 > primitive_bitwidth(int1)=1. + # _implicit_cvt(True) also hits this (bool is subclass of int, mangle_type->"i1"). + # Workaround: coerce to Python int so _implicit_cvt creates an int32 tensor. + # This is value-correct (0/1) though not type-correct (int32, not int1). + # The interpreter itself has the same limitation -- to_tensor(True) goes through + # a similar path. A proper fix belongs in upstream Triton's TensorHandle. if isinstance(py_val, bool): py_val = int(py_val) return _implicit_cvt(py_val) From 281835e1b16e4492b4784ea4368cd63fd463f0f5 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 19 Mar 2026 01:33:07 +0000 Subject: [PATCH 4/7] [TEST] Add failing tests for constexpr.to() dtype preservation Unit tests assert that _constexpr_to returns a tensor with the exact destination dtype (uint8, int16, uint32, float64). E2E test verifies unsigned integer division with a constexpr denominator. --- tests/end_to_end/test_sanitizer.py | 18 +++++++++++++++++ tests/unit/test_constexpr_to.py | 32 ++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 tests/unit/test_constexpr_to.py diff --git a/tests/end_to_end/test_sanitizer.py b/tests/end_to_end/test_sanitizer.py index e73ab1b7..a5b67a4c 100644 --- a/tests/end_to_end/test_sanitizer.py +++ b/tests/end_to_end/test_sanitizer.py @@ -1468,3 +1468,21 @@ def test_constexpr_to_unsupported_dtype_raises(): continue with pytest.raises(TypeError, match="not supported"): _constexpr_to(mock_self, dt) + + +# ======== Unsigned Division Regression Test =========== + + +@triton_viz.trace(client="tracer") +@triton.jit +def unsigned_div_kernel(out_ptr, DENOM: tl.constexpr): + numerator = tl.full((1,), 8, tl.uint32) + denom = DENOM.to(tl.uint32) + out = numerator // denom + tl.store(out_ptr, out) + + +def test_constexpr_unsigned_div(): + out = torch.empty(1, dtype=torch.int32) + unsigned_div_kernel[(1,)](out, DENOM=1) + assert out.item() == 8 diff --git a/tests/unit/test_constexpr_to.py b/tests/unit/test_constexpr_to.py new file mode 100644 index 00000000..a8a64f3a --- /dev/null +++ b/tests/unit/test_constexpr_to.py @@ -0,0 +1,32 @@ +import types + +import pytest +import numpy as np +import triton.language as tl + +from triton_viz.core.patch import _constexpr_to + + +@pytest.mark.parametrize( + "value,triton_dtype,expected_np,expected_val", + [ + (1.0, tl.float64, np.float64, 1.0), # float64 demoted to float32 + (1, tl.uint8, np.uint8, 1), # uint8 demoted to int32 + (-7, tl.int16, np.int16, -7), # int16 demoted to int32 + (1, tl.uint32, np.uint32, 1), # uint32 demoted to int32 + ], +) +def test_constexpr_to_preserves_dtype(value, triton_dtype, expected_np, expected_val): + mock_self = types.SimpleNamespace(value=value) + ret = _constexpr_to(mock_self, triton_dtype) + assert isinstance(ret, tl.core.tensor) + assert ret.dtype == triton_dtype + assert ret.handle.data.dtype == np.dtype(expected_np) + assert ret.handle.data.item() == expected_val + + +def test_constexpr_bitcast_preserves_dtype(): + mock_self = types.SimpleNamespace(value=42) + ret = _constexpr_to(mock_self, tl.float32, bitcast=True) + assert isinstance(ret, tl.core.tensor) + assert ret.dtype == tl.float32 From 17cc5b25cd77b2124f5b92025f26803fd5abf7d6 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 19 Mar 2026 01:35:15 +0000 Subject: [PATCH 5/7] [FIX] Preserve explicit destination dtype in constexpr.to() Replace _implicit_cvt return path with _typed_scalar_tensor helper that wraps the cast result using the exact target dtype. This fixes dtype demotion for uint8, int16, uint32, float64, and bool/int1 casts. --- tests/end_to_end/test_sanitizer.py | 2 +- triton_viz/core/patch.py | 23 ++++++++++++----------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/tests/end_to_end/test_sanitizer.py b/tests/end_to_end/test_sanitizer.py index a5b67a4c..d65693b3 100644 --- a/tests/end_to_end/test_sanitizer.py +++ b/tests/end_to_end/test_sanitizer.py @@ -1479,7 +1479,7 @@ def unsigned_div_kernel(out_ptr, DENOM: tl.constexpr): numerator = tl.full((1,), 8, tl.uint32) denom = DENOM.to(tl.uint32) out = numerator // denom - tl.store(out_ptr, out) + tl.store(out_ptr + tl.arange(0, 1), out) def test_constexpr_unsigned_div(): diff --git a/triton_viz/core/patch.py b/triton_viz/core/patch.py index 55a01947..29414384 100644 --- a/triton_viz/core/patch.py +++ b/triton_viz/core/patch.py @@ -21,6 +21,7 @@ import numpy as np from triton.runtime.interpreter import ( GridExecutor, + TensorHandle, _implicit_cvt, _get_np_dtype, interpreter_builder, @@ -119,6 +120,16 @@ def _fp_downcast_rtz(value: float, dst_np_dtype: np.dtype) -> np.generic: return rtne_result +def _typed_scalar_tensor(result: np.generic, dtype: tl.core.dtype) -> tl.core.tensor: + """Wrap a numpy scalar as a Triton interpreter tensor with an exact dtype.""" + storage_np = _get_np_dtype(dtype) + arr = np.array([result]) + if arr.dtype != storage_np: + # Semantic/storage dtype differ (e.g. ml_dtypes.bfloat16 -> uint16) + arr = arr.view(storage_np) + return tl.core.tensor(TensorHandle(arr, dtype), dtype) + + def _constexpr_to(self, dtype, fp_downcast_rounding=None, bitcast=False): """Interpreter-mode implementation of constexpr.to(dtype). @@ -168,17 +179,7 @@ def _constexpr_to(self, dtype, fp_downcast_rounding=None, bitcast=False): else: result = dst_np.type(value) - py_val = result.item() - # bool/int1: Triton's TensorHandle.__post_init__ rejects np.array([True], dtype=np.int32) - # paired with tl.int1 because itemsize(int32)=32 > primitive_bitwidth(int1)=1. - # _implicit_cvt(True) also hits this (bool is subclass of int, mangle_type->"i1"). - # Workaround: coerce to Python int so _implicit_cvt creates an int32 tensor. - # This is value-correct (0/1) though not type-correct (int32, not int1). - # The interpreter itself has the same limitation -- to_tensor(True) goes through - # a similar path. A proper fix belongs in upstream Triton's TensorHandle. - if isinstance(py_val, bool): - py_val = int(py_val) - return _implicit_cvt(py_val) + return _typed_scalar_tensor(result, dtype) def _constexpr_getattr(self, name): From a78d02c1b20c6e4019d9f109f5b21b37a08d6539 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 19 Mar 2026 19:51:50 +0000 Subject: [PATCH 6/7] [FIX] Validate fp_downcast_rounding, guard integer overflow, feed raw values to grid lambda - Reject invalid fp_downcast_rounding values (e.g. "RTZ", "foo") with ValueError instead of silently falling back to rtne. - Raise OverflowError for integer literals outside [-2**63, 2**64) in _src_np_dtype, matching upstream Triton's representable range. - Pass raw host values (not tl.constexpr wrappers) to grid lambdas so isinstance checks in grid/heuristic functions behave as in real Triton. --- tests/end_to_end/test_sanitizer.py | 12 ++++++++++-- tests/unit/test_constexpr_to.py | 22 +++++++++++++++++++++- triton_viz/core/patch.py | 22 ++++++++++++++++++++-- 3 files changed, 51 insertions(+), 5 deletions(-) diff --git a/tests/end_to_end/test_sanitizer.py b/tests/end_to_end/test_sanitizer.py index d65693b3..b1770de1 100644 --- a/tests/end_to_end/test_sanitizer.py +++ b/tests/end_to_end/test_sanitizer.py @@ -1349,10 +1349,18 @@ def grid_lambda_kernel(out_ptr, N: tl.constexpr, BLOCK: tl.constexpr): def test_constexpr_grid_lambda(): - """Grid lambda with META dict must work with universally-wrapped constexprs.""" + """Grid lambda receives raw host values, not tl.constexpr wrappers.""" N = 256 out = torch.empty(N, dtype=torch.int32) - grid = lambda META: (triton.cdiv(N, META["BLOCK"]),) + + def grid(META): + # Upstream Triton feeds raw Python values to grid lambdas; + # verify isinstance checks work as they would in real Triton. + assert isinstance( + META["BLOCK"], int + ), f"grid lambda should see raw int, got {type(META['BLOCK'])}" + return (triton.cdiv(N, META["BLOCK"]),) + grid_lambda_kernel[grid](out, N=N, BLOCK=128) assert (out == torch.arange(N, dtype=torch.int32)).all() diff --git a/tests/unit/test_constexpr_to.py b/tests/unit/test_constexpr_to.py index a8a64f3a..c3a6befb 100644 --- a/tests/unit/test_constexpr_to.py +++ b/tests/unit/test_constexpr_to.py @@ -4,7 +4,7 @@ import numpy as np import triton.language as tl -from triton_viz.core.patch import _constexpr_to +from triton_viz.core.patch import _constexpr_to, _src_np_dtype @pytest.mark.parametrize( @@ -30,3 +30,23 @@ def test_constexpr_bitcast_preserves_dtype(): ret = _constexpr_to(mock_self, tl.float32, bitcast=True) assert isinstance(ret, tl.core.tensor) assert ret.dtype == tl.float32 + + +@pytest.mark.parametrize("bad_mode", ["RTZ", "foo", ""]) +def test_constexpr_to_invalid_rounding_mode_raises(bad_mode): + """Invalid fp_downcast_rounding values must raise ValueError, not silently fall back.""" + mock_self = types.SimpleNamespace(value=1.0) + with pytest.raises(ValueError, match="fp_downcast_rounding must be one of"): + _constexpr_to(mock_self, tl.float16, fp_downcast_rounding=bad_mode) + + +def test_constexpr_to_overflow_raises(): + """Integer literals outside [-2**63, 2**64) must raise OverflowError.""" + mock_self = types.SimpleNamespace(value=-(2**63) - 1) + with pytest.raises(OverflowError, match="outside the representable range"): + _constexpr_to(mock_self, tl.int64) + + +def test_src_np_dtype_max_uint64(): + """Values in [2**63, 2**64) must map to uint64, not overflow.""" + assert _src_np_dtype(2**64 - 1) == np.dtype(np.uint64) diff --git a/triton_viz/core/patch.py b/triton_viz/core/patch.py index 29414384..d4637760 100644 --- a/triton_viz/core/patch.py +++ b/triton_viz/core/patch.py @@ -67,7 +67,12 @@ def _src_np_dtype(value: object) -> np.dtype: return np.dtype(np.uint32) if -(2**63) <= value < 2**63: return np.dtype(np.int64) - return np.dtype(np.uint64) + if 0 <= value < 2**64: + return np.dtype(np.uint64) + raise OverflowError( + f"Integer literal {value} is outside the representable range " + f"[-2**63, 2**64) for Triton integer types" + ) # Float: match Triton to_tensor() — float32 for representable values assert isinstance(value, float) abs_x: float = abs(value) @@ -162,6 +167,12 @@ def _constexpr_to(self, dtype, fp_downcast_rounding=None, bitcast=False): src_np = _src_np_dtype(value) if fp_downcast_rounding is not None: + _VALID_ROUNDING_MODES = ("rtne", "rtz") + if fp_downcast_rounding not in _VALID_ROUNDING_MODES: + raise ValueError( + f"fp_downcast_rounding must be one of {_VALID_ROUNDING_MODES}, " + f"got {fp_downcast_rounding!r}" + ) src_is_float = src_np.kind == "f" dst_is_float = dst_np.kind == "f" or ( _ml_dtypes is not None and dst_np == np.dtype(_ml_dtypes.bfloat16) @@ -617,6 +628,13 @@ def _grid_executor_call(self, *args_dev, backend=None, **kwargs): # Prepare call arguments args = inspect.getcallargs(self.fn, *args_hst, **kwargs_hst) + # grid_args keeps constexprs as raw host values (matching upstream Triton) + # so that grid lambdas see plain Python types, not tl.constexpr wrappers. + grid_args = { + name: (arg if name in self.constexprs else _implicit_cvt(arg)) + for name, arg in args.items() + } + grid_args.pop("self", None) call_args = {} for name, arg in args.items(): if name in self.constexprs: @@ -627,7 +645,7 @@ def _grid_executor_call(self, *args_dev, backend=None, **kwargs): call_args[name] = ret call_args.pop("self", None) # Iterate through grid - grid = self.grid(call_args) if callable(self.grid) else self.grid + grid = self.grid(grid_args) if callable(self.grid) else self.grid assert len(grid) <= 3 grid = grid + (1,) * (3 - len(grid)) From c1544e9749d8e660dcfe2d1d3ea89250dd28fce1 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Sat, 21 Mar 2026 04:22:46 +0000 Subject: [PATCH 7/7] [FIX] Fix integer cast wrapping semantics and bitcast primitive_bitwidth validation - Replace dst_np.type(value) with np.array(src).astype(dst) to get proper wrap/truncate semantics matching upstream Triton's cast_impl (e.g. 300->uint8=44, -1->uint32=4294967295). - Validate bitcast using Triton primitive_bitwidth instead of numpy storage itemsize, so bool(int1)->int8 bitcast correctly rejects (1 bit != 8 bits) instead of passing on storage size equality. - Extract _src_triton_dtype() mirroring upstream to_tensor type rules; _src_np_dtype() now delegates to it. --- tests/end_to_end/test_sanitizer.py | 4 +-- tests/unit/test_constexpr_to.py | 20 ++++++++++++--- triton_viz/core/patch.py | 40 +++++++++++++++++------------- 3 files changed, 41 insertions(+), 23 deletions(-) diff --git a/tests/end_to_end/test_sanitizer.py b/tests/end_to_end/test_sanitizer.py index b1770de1..2af88e87 100644 --- a/tests/end_to_end/test_sanitizer.py +++ b/tests/end_to_end/test_sanitizer.py @@ -1459,8 +1459,8 @@ def test_constexpr_bitcast_mismatched_size_raises(): import types mock_self = types.SimpleNamespace(value=1.0) - # float(1.0) -> float32 (4 bytes), tl.int64 (8 bytes) -> mismatch - with pytest.raises(ValueError, match="different sizes"): + # float(1.0) -> float32 (32 bits), tl.int64 (64 bits) -> mismatch + with pytest.raises(ValueError, match="Cannot bitcast"): _constexpr_to(mock_self, tl.int64, bitcast=True) diff --git a/tests/unit/test_constexpr_to.py b/tests/unit/test_constexpr_to.py index c3a6befb..8f013f4f 100644 --- a/tests/unit/test_constexpr_to.py +++ b/tests/unit/test_constexpr_to.py @@ -10,10 +10,15 @@ @pytest.mark.parametrize( "value,triton_dtype,expected_np,expected_val", [ - (1.0, tl.float64, np.float64, 1.0), # float64 demoted to float32 - (1, tl.uint8, np.uint8, 1), # uint8 demoted to int32 - (-7, tl.int16, np.int16, -7), # int16 demoted to int32 - (1, tl.uint32, np.uint32, 1), # uint32 demoted to int32 + (1.0, tl.float64, np.float64, 1.0), + (1, tl.uint8, np.uint8, 1), + (-7, tl.int16, np.int16, -7), + (1, tl.uint32, np.uint32, 1), + # Narrowing / signedness-change casts — must wrap, not reject + (300, tl.uint8, np.uint8, 300 % 256), # 44 + (-1, tl.uint32, np.uint32, 2**32 - 1), # 4294967295 + (2**31, tl.int32, np.int32, -(2**31)), # signed wrap + (255, tl.int8, np.int8, -1), # unsigned → signed wrap ], ) def test_constexpr_to_preserves_dtype(value, triton_dtype, expected_np, expected_val): @@ -47,6 +52,13 @@ def test_constexpr_to_overflow_raises(): _constexpr_to(mock_self, tl.int64) +def test_constexpr_bitcast_bool_to_int8_raises(): + """Bitcast bool(int1) -> int8 must raise: primitive_bitwidth 1 != 8.""" + mock_self = types.SimpleNamespace(value=True) + with pytest.raises(ValueError, match="Cannot bitcast"): + _constexpr_to(mock_self, tl.int8, bitcast=True) + + def test_src_np_dtype_max_uint64(): """Values in [2**63, 2**64) must map to uint64, not overflow.""" assert _src_np_dtype(2**64 - 1) == np.dtype(np.uint64) diff --git a/triton_viz/core/patch.py b/triton_viz/core/patch.py index d4637760..69d4a118 100644 --- a/triton_viz/core/patch.py +++ b/triton_viz/core/patch.py @@ -56,24 +56,23 @@ _ml_dtypes = None # type: ignore[assignment] -def _src_np_dtype(value: object) -> np.dtype: - """Infer the numpy dtype for a Python literal, matching Triton's to_tensor logic.""" +def _src_triton_dtype(value: object) -> tl.core.dtype: + """Infer the Triton dtype for a Python literal, mirroring upstream to_tensor.""" if isinstance(value, bool): - return np.dtype(np.bool_) + return tl.int1 if isinstance(value, int): if -(2**31) <= value < 2**31: - return np.dtype(np.int32) + return tl.int32 if 2**31 <= value < 2**32: - return np.dtype(np.uint32) + return tl.uint32 if -(2**63) <= value < 2**63: - return np.dtype(np.int64) + return tl.int64 if 0 <= value < 2**64: - return np.dtype(np.uint64) + return tl.uint64 raise OverflowError( f"Integer literal {value} is outside the representable range " f"[-2**63, 2**64) for Triton integer types" ) - # Float: match Triton to_tensor() — float32 for representable values assert isinstance(value, float) abs_x: float = abs(value) if ( @@ -82,8 +81,13 @@ def _src_np_dtype(value: object) -> np.dtype: or abs_x == float("inf") or (_F32_MIN_NORMAL <= abs_x <= _F32_MAX) ): - return np.dtype(np.float32) - return np.dtype(np.float64) + return tl.float32 + return tl.float64 + + +def _src_np_dtype(value: object) -> np.dtype: + """Infer the numpy dtype for a Python literal, matching Triton's to_tensor logic.""" + return np.dtype(_get_np_dtype(_src_triton_dtype(value))) def _cast_np_dtype(triton_dtype: tl.core.dtype) -> np.dtype: @@ -152,14 +156,16 @@ def _constexpr_to(self, dtype, fp_downcast_rounding=None, bitcast=False): bitcast_val = bitcast.value if isinstance(bitcast, tl.constexpr) else bitcast if bitcast_val: - src_np = _src_np_dtype(value) - dst_np = np.dtype(_get_np_dtype(dtype)) - if src_np.itemsize != dst_np.itemsize: + src_triton = _src_triton_dtype(value) + src_bits = src_triton.primitive_bitwidth + dst_bits = dtype.primitive_bitwidth + if src_bits != dst_bits: raise ValueError( - f"Cannot bitcast between types of different sizes: " - f"{src_np} ({src_np.itemsize * 8} bits) and " - f"{dst_np} ({dst_np.itemsize * 8} bits)" + f"Cannot bitcast data-type of size {src_bits} to " + f"data-type of size {dst_bits}" ) + src_np = _src_np_dtype(value) + dst_np = np.dtype(_get_np_dtype(dtype)) raw = np.array([value], dtype=src_np).tobytes() result = np.frombuffer(raw, dtype=dst_np)[0] else: @@ -188,7 +194,7 @@ def _constexpr_to(self, dtype, fp_downcast_rounding=None, bitcast=False): if fp_downcast_rounding == "rtz": result = _fp_downcast_rtz(value, dst_np) else: - result = dst_np.type(value) + result = np.array([value], dtype=src_np).astype(dst_np)[0] return _typed_scalar_tensor(result, dtype)