Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions tests/end_to_end/test_sanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1265,3 +1265,41 @@ def test_reinterpret_tensor_wrapper():
x = torch.ones(N, dtype=torch.float16, device="cpu")
y = torch.empty(N, dtype=torch.float16, device="cpu")
copy_kernel[(1,)](triton.reinterpret(x, tl.float16), y, N, BLOCK=64)


# ======== Constexpr Positional Argument Tests ===========

constexpr_str_annotation_sanitizer = SymbolicSanitizer(abort_on_error=False)


@triton_viz.trace(client=constexpr_str_annotation_sanitizer)
@triton.jit
def constexpr_str_annotation_kernel(X, Out, N, BLOCK_SIZE: "tl.constexpr"):
pid = tl.program_id(0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offs < N
x = tl.load(X + offs, mask=mask)
tl.store(Out + offs, x + 1.0, mask=mask)


def test_constexpr_string_annotation_positional_arg():
"""String-annotated tl.constexpr passed positionally must not raise ValueError.

On Python 3.14+ the AST rewriter double-quotes string annotations,
so the rewritten function's __annotations__ contains "'tl.constexpr'"
instead of "tl.constexpr", making GridExecutor.constexprs empty.
The fix derives constexpr names from the original JITFunction.params.
"""
constexpr_str_annotation_sanitizer.records.clear()

N = 16
BLOCK_SIZE = 4
x = torch.arange(N, dtype=torch.float32)
out = torch.empty_like(x)

# BLOCK_SIZE passed as positional arg with string annotation
constexpr_str_annotation_kernel[(N // BLOCK_SIZE,)](x, out, N, BLOCK_SIZE)

assert (
len(constexpr_str_annotation_sanitizer.records) == 0
), f"Expected no OOB records, got {len(constexpr_str_annotation_sanitizer.records)}"
14 changes: 12 additions & 2 deletions triton_viz/core/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,17 +437,27 @@ def _grid_executor_call(self, *args_dev, backend=None, **kwargs):
# Expose client_manager to tl.flip wrapper via a module-global
global _current_client_manager
_current_client_manager = client_manager
kwargs.pop("jit_fn")
jit_fn = kwargs.pop("jit_fn")
if cfg.virtual_memory:
args_hst, kwargs_hst = _init_args_hst(args_dev, kwargs)
else:
args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs)

# Derive constexpr names from the original JIT function, which has
# reliable annotation info. The rewritten function's __annotations__
# can be corrupted (string annotations get double-quoted), making
# self.constexprs unreliable.
# jit_fn is None for nested JIT calls and some Autotune paths.
if jit_fn is not None:
constexpr_names = {p.name for p in jit_fn.params if p.is_constexpr}
else:
constexpr_names = set(self.constexprs)

# Prepare call arguments
args = inspect.getcallargs(self.fn, *args_hst, **kwargs_hst)
call_args = {}
for name, arg in args.items():
if name in self.constexprs:
if name in constexpr_names:
call_args[name] = arg
ret = arg
else:
Expand Down
Loading