
[FIX] Support constexpr .to() in interpreter mode #310

Open
mark14wu wants to merge 10 commits into main from fix/constexpr-to-method


mark14wu (Collaborator) commented Mar 4, 2026

Summary

  • In compiled Triton, constexpr.to(dtype) works via the compiler pipeline, but in interpreter mode constexpr has no .to() method
  • Raw Python scalars passed as constexprs are not wrapped in tl.constexpr
  • Three coupled fixes: monkey-patch .to(), wrap scalars in tl.constexpr, unwrap in SymbolicExpr.from_value

Test plan

  • test_float_no_attr_to: eps.to(dtype), where eps is a constexpr float

github-actions bot commented Mar 4, 2026

Sanitizer Performance Benchmark

| Benchmark | main (min) | PR (min) | Change |
|---|---|---|---|
| gemm | 0.188s | 0.189s | +0.9% |
| gemm_oob | 0.194s | 0.194s | +0.1% |
| indirect_load | 0.295s | 0.299s | +1.1% |
| nested_loop | 0.378s | 0.376s | -0.3% |
| block_pointer_loop_advance | 0.190s | 0.186s | -2.1% |
| liger_jsd | 0.154s | 0.155s | +0.3% |
| flaggems_layernorm | 0.463s | 0.468s | +1.3% |
| swiglu | 0.188s | 0.188s | +0.0% |
| cross_entropy | 0.173s | 0.177s | +2.3% |
| fused_linear_jsd | 0.231s | 0.238s | +3.2% |
| Total | 2.454s | 2.471s | +0.7% |

Iterations: 1 warmup + 20 measured

mark14wu force-pushed the fix/constexpr-to-method branch from 6368485 to 357479a on March 4, 2026 03:08
In compiled Triton, constexpr.to(dtype) works via the compiler pipeline,
but in interpreter mode constexpr has no .to() method, and raw Python
scalars passed as constexprs are not wrapped.

Three coupled fixes:
- Monkey-patch tl.constexpr.to() for interpreter mode
- Wrap scalar constexprs in tl.constexpr in _grid_executor_call
- Unwrap tl.constexpr in SymbolicExpr.from_value to prevent type errors
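
The three fixes listed above can be sketched without Triton installed. This is a minimal illustration under stated assumptions, not the PR's code: `Constexpr`, `wrap_scalar`, and `unwrap` are hypothetical stand-ins for `tl.constexpr`, the wrapping step in `_grid_executor_call`, and the unwrapping step in `SymbolicExpr.from_value`, and the cast is reduced to calling a plain Python type.

```python
class Constexpr:
    """Hypothetical stand-in for tl.constexpr: a thin wrapper around a host value."""

    def __init__(self, value):
        self.value = value


def _constexpr_to(self, dtype):
    # Stand-in cast: here `dtype` is a plain Python type such as int or float,
    # not a Triton dtype.
    return dtype(self.value)


# Fix 1: give the wrapper a .to() method by patching the class.
Constexpr.to = _constexpr_to


def wrap_scalar(arg):
    # Fix 2: wrap raw Python scalars so they expose .to(); leave other values alone.
    return Constexpr(arg) if isinstance(arg, (int, float, bool)) else arg


def unwrap(value):
    # Fix 3: strip the wrapper before handing the value to code that expects
    # a raw scalar.
    return value.value if isinstance(value, Constexpr) else value


eps = wrap_scalar(1e-5)
print(eps.to(float))        # cast through the patched method
print(unwrap(eps))          # raw host value again
print(wrap_scalar("name"))  # non-scalars pass through unchanged
```

The fixes are coupled because wrapping (fix 2) is what makes `.to()` reachable on a scalar, while unwrapping (fix 3) keeps the wrapper from leaking into code that only understands raw values.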
mark14wu force-pushed the fix/constexpr-to-method branch from 357479a to 0524caf on March 4, 2026 03:19
# Conflicts:
#	tests/end_to_end/test_sanitizer.py
    call_args[name] = arg
    ret = arg
    call_args[name] = (
        tl.constexpr(arg) if isinstance(arg, (int, float, bool)) else arg
Member


Why should you make it constexpr only for int, float, and bool?

…g, universal normalization, and bool support

- Replace no-op _implicit_cvt shim with numpy-based cast that handles
  truncation (float→int) and bitcast (reinterpret bits)
- Move constexpr monkey-patch from module-level to scoped _patch_constexpr()
  via _LangPatchScope, with hasattr guards for upstream compatibility
- Add __getattr__ proxy on constexpr for attribute access (e.g. DTYPE.is_int())
- Universally wrap all constexpr args via _normalize_constexpr_arg(), not just
  int/float/bool, enabling dtype/tuple/custom-type meta-params
- Add bool to symbolic engine's builtin_scala_types with int1 dtype inference
- Add E2E tests for real cast, bitcast, dtype meta-param, tuple meta-param,
  grid lambda, and bool cast
- Add unit tests for bool symbolic inference and constexpr patch lifecycle
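
The scoped patch and the `__getattr__` proxy described in the list above might look roughly like the following. The internals of `_LangPatchScope` are not shown in this conversation, so `PatchScope` here is a hypothetical stand-in that only mirrors the `set_attr` call visible in the diff, and `Constexpr` again substitutes for `tl.constexpr`.

```python
_MISSING = object()


class PatchScope:
    """Hypothetical stand-in for _LangPatchScope: remembers originals for undo."""

    def __init__(self):
        self._undo = []

    def set_attr(self, obj, name, value):
        # Record what the object itself defined before, if anything.
        original = vars(obj).get(name, _MISSING)
        self._undo.append((obj, name, original))
        setattr(obj, name, value)

    def restore(self):
        # Undo in reverse order so nested patches unwind cleanly.
        while self._undo:
            obj, name, original = self._undo.pop()
            if original is _MISSING:
                delattr(obj, name)
            else:
                setattr(obj, name, original)


class Constexpr:
    """Hypothetical stand-in for tl.constexpr."""

    def __init__(self, value):
        self.value = value


def _constexpr_getattr(self, name):
    # Only called when normal lookup fails; forward to the wrapped value.
    if name == "value":
        raise AttributeError(name)  # avoid recursion if .value is not set yet
    return getattr(self.value, name)


scope = PatchScope()
# hasattr guards: skip the patch if the class already provides the attribute.
if not hasattr(Constexpr, "to"):
    scope.set_attr(Constexpr, "to", lambda self, dtype: dtype(self.value))
if not hasattr(Constexpr, "__getattr__"):
    scope.set_attr(Constexpr, "__getattr__", _constexpr_getattr)

wrapped = Constexpr(2.5)
inside = (wrapped.to(int), wrapped.is_integer())  # is_integer comes from float
scope.restore()
patched_after_restore = hasattr(Constexpr, "to")
```

Keeping the patch inside a scope means the class is left untouched once interpreter mode ends, which is the motivation the commit gives for moving away from a module-level monkey-patch.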
Triton's semantic.cast() raises ValueError when source and destination
primitive bitwidths differ for bitcast. Our interpreter-mode
_constexpr_to was silently truncating via a bytes slice instead.
Add an explicit size check and remove the unnecessary slice.

Unit tests assert that _constexpr_to returns a tensor with the
exact destination dtype (uint8, int16, uint32, float64). E2E test
verifies unsigned integer division with a constexpr denominator.

Replace _implicit_cvt return path with _typed_scalar_tensor helper
that wraps the cast result using the exact target dtype. This fixes
dtype demotion for uint8, int16, uint32, float64, and bool/int1 casts.
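
As a rough NumPy analogy for why the exact destination dtype matters, the sketch below keeps the cast result as a typed 0-d array instead of a bare Python number. `typed_scalar` is a hypothetical helper for illustration only; the PR's `_typed_scalar_tensor` produces a Triton interpreter tensor, which is not reproduced here.

```python
import numpy as np


def typed_scalar(value, dst):
    # A 0-d array carries its dtype with it, unlike a bare Python int or float.
    return np.array(value).astype(dst)


x = typed_scalar(7, np.uint8)
demoted = x.item()  # plain Python int: the uint8 information is gone

# Unsigned division depends on the operands keeping their unsigned dtype.
numerator = typed_scalar(-1, np.uint32)  # wraps to the largest uint32 value
denominator = typed_scalar(2, np.uint32)
quotient = numerator // denominator

print(x.dtype, type(demoted).__name__, quotient, quotient.dtype)
```

With plain Python integers, `-1 // 2` is `-1`; with both operands held as `uint32`, the same expression divides the wrapped unsigned value instead.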
… values to grid lambda

- Reject invalid fp_downcast_rounding values (e.g. "RTZ", "foo") with
  ValueError instead of silently falling back to rtne.
- Raise OverflowError for integer literals outside [-2**63, 2**64) in
  _src_np_dtype, matching upstream Triton's representable range.
- Pass raw host values (not tl.constexpr wrappers) to grid lambdas so
  isinstance checks in grid/heuristic functions behave as in real Triton.
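
The first two checks in the list above could be expressed as small validators like these. This is a sketch under assumptions: the function names are hypothetical, and the accepted rounding spellings are assumed to be the lowercase `"rtne"` and `"rtz"`; only the rejected examples and the integer range come from the commit message.

```python
from typing import Optional

# Assumption: only these lowercase spellings are accepted.
_VALID_ROUNDING = ("rtne", "rtz")


def check_rounding(mode: Optional[str]) -> Optional[str]:
    # None means "no explicit rounding mode requested".
    if mode is None:
        return None
    if mode not in _VALID_ROUNDING:
        # Fail loudly instead of silently falling back to rtne.
        raise ValueError(f"invalid fp_downcast_rounding: {mode!r}")
    return mode


def check_int_literal(value: int) -> int:
    # Representable range stated in the commit message: [-2**63, 2**64).
    if not (-2**63 <= value < 2**64):
        raise OverflowError(f"integer literal out of range: {value}")
    return value


for candidate in ("rtz", "RTZ", "foo"):
    try:
        print(candidate, "->", check_rounding(candidate))
    except ValueError as exc:
        print(candidate, "-> rejected:", exc)
```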
…dth validation

- Replace dst_np.type(value) with np.array(src).astype(dst) to get
  proper wrap/truncate semantics matching upstream Triton's cast_impl
  (e.g. 300->uint8=44, -1->uint32=4294967295).
- Validate bitcast using Triton primitive_bitwidth instead of numpy
  storage itemsize, so bool(int1)->int8 bitcast correctly rejects
  (1 bit != 8 bits) instead of passing on storage size equality.
- Extract _src_triton_dtype() mirroring upstream to_tensor type rules;
  _src_np_dtype() now delegates to it.
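
The wrap/truncate and bitwidth points above can be illustrated with NumPy alone. The two lookup tables and the function names below are hypothetical and exist only for this sketch; the bitwidths follow the commit's statement that bool is `int1` (1 bit) while `int8` is 8 bits.

```python
import numpy as np

# Hypothetical tables for illustration: dtype name -> NumPy storage dtype,
# and dtype name -> Triton-style primitive bitwidth.
NP_DTYPE = {"int1": np.bool_, "int8": np.int8, "uint8": np.uint8,
            "uint32": np.uint32, "float32": np.float32}
PRIMITIVE_BITWIDTH = {"int1": 1, "int8": 8, "uint8": 8,
                      "uint32": 32, "float32": 32}


def value_cast(value, dst):
    # astype on an array converts C-style: out-of-range integers wrap
    # modulo 2**bits instead of raising.
    return np.array(value).astype(NP_DTYPE[dst])


def bitcast(value, src, dst):
    # Compare logical bitwidths, not storage sizes: NumPy stores bool in one
    # byte, the same as int8, so an itemsize check would wrongly accept it.
    if PRIMITIVE_BITWIDTH[src] != PRIMITIVE_BITWIDTH[dst]:
        raise ValueError(f"cannot bitcast {src} to {dst}: bitwidths differ")
    # Reinterpret the same bits under the destination dtype.
    return np.array([value], dtype=NP_DTYPE[src]).view(NP_DTYPE[dst])[0]


print(value_cast(300, "uint8"))           # 300 wraps modulo 256
print(value_cast(-1, "uint32"))           # -1 wraps to the largest uint32
print(hex(int(bitcast(1.0, "float32", "uint32"))))  # raw IEEE-754 bits of 1.0
same_storage = np.dtype(np.bool_).itemsize == np.dtype(np.int8).itemsize
```

`same_storage` evaluates to true, which is exactly why a storage-size comparison cannot tell `int1` apart from `int8`.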
    def _patch_constexpr(scope: _LangPatchScope) -> None:
        """Patch tl.constexpr with .to() and __getattr__ for interpreter mode."""
        if not hasattr(tl.constexpr, "to"):
            scope.set_attr(tl.constexpr, "to", _constexpr_to)
Member


There isn't a constexpr.to method. This is tensor.to:

    @builtin
    def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _semantic=None):
        """
        Alias for :py:func:`tensor.cast`.
        """
        return cast(self, dtype, fp_downcast_rounding, bitcast, _semantic=_semantic)

