[FIX] Patch libdevice stubs for interpreter/sanitizer mode #305

Open: mark14wu wants to merge 9 commits into main from fix/none-to-tensor-type-error

@mark14wu (Collaborator) commented Mar 2, 2026

Summary

  • Libdevice stub functions (e.g. tanh) return None in interpreter/sanitizer mode because Triton's _patch_lang does not patch the triton.language.extra.libdevice module. Passing that result to tl.store raises TypeError: cannot convert None of type <class 'NoneType'> to tensor
  • Replace libdevice stubs with numpy-backed implementations that route through interpreter_builder.unary_op, handling both module attribute access (libdevice.tanh(x)) and direct imports (from ... import tanh)
  • Register tanh as a symbolic unary operation (UNARY_OPS + _UNARY_NUMPY_TO_SYM_OP) so the sanitizer can track it
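
The two access patterns named above can be illustrated with a minimal sketch. It assumes nothing about Triton itself: `fake_libdevice`, `patch_tanh`, and `restore` are hypothetical names, and `math.tanh` stands in for the numpy-backed implementation that the PR routes through interpreter_builder.unary_op.

```python
import math
import types

# Hypothetical stand-in for triton.language.extra.libdevice.
fake_libdevice = types.ModuleType("fake_libdevice")


def _tanh_stub(x):
    """Body-less stub: returns None, like an unpatched libdevice function."""


fake_libdevice.tanh = _tanh_stub
tanh = fake_libdevice.tanh  # simulates `from ... import tanh`


def kernel_module_style(x):
    # The attribute is looked up on the module at call time.
    return fake_libdevice.tanh(x)


def kernel_direct_import(x):
    # The name is looked up in the kernel's globals at call time.
    return tanh(x)


def patch_tanh(module, kernel_globals, replacement=math.tanh):
    """Swap the stub on the module and in the kernel's globals.

    Returns a list of (namespace, name, original) records for restore().
    """
    stub = module.tanh
    saved = [(vars(module), "tanh", stub)]
    module.tanh = replacement
    # Any global bound to the same stub object is an alias of it.
    for name, value in list(kernel_globals.items()):
        if value is stub:
            saved.append((kernel_globals, name, value))
            kernel_globals[name] = replacement
    return saved


def restore(saved):
    """Undo the patch in reverse order."""
    for namespace, name, original in reversed(saved):
        namespace[name] = original
```

Calling `patch_tanh(fake_libdevice, kernel_direct_import.__globals__)` makes both kernels return a real value instead of None, and `restore(...)` on the returned list puts the stubs back.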

Test plan

  • Added test_libdevice_tanh end-to-end test that exercises the exact reproducer pattern (load → tanh → store)
  • All 21 existing sanitizer tests continue to pass

mark14wu added 2 commits March 2, 2026 06:15
Libdevice stub functions (e.g. `tanh`) return None in interpreter mode
because Triton's _patch_lang does not patch the libdevice module. This
causes TypeError when the result is passed to tl.store.

Replace libdevice stubs with numpy-backed implementations that route
through interpreter_builder.unary_op, and register tanh as a symbolic
unary operation so the sanitizer can track it.

Resolve conflicts: keep libdevice patching (HEAD) + client_manager param (main),
and keep both libdevice and fake tensor test sections.

github-actions bot commented Mar 4, 2026

Sanitizer Performance Benchmark

| Benchmark | main (min) | PR (min) | Change |
|---|---|---|---|
| gemm | 0.183s | 0.184s | +0.5% |
| gemm_oob | 0.191s | 0.191s | +0.2% |
| indirect_load | 0.293s | 0.293s | +0.2% |
| nested_loop | 0.372s | 0.374s | +0.7% |
| block_pointer_loop_advance | 0.186s | 0.188s | +1.0% |
| liger_jsd | 0.150s | 0.151s | +0.5% |
| flaggems_layernorm | 0.457s | 0.461s | +0.9% |
| swiglu | 0.184s | 0.186s | +1.4% |
| cross_entropy | 0.172s | 0.174s | +0.8% |
| fused_linear_jsd | 0.225s | 0.228s | +1.3% |
| Total | 2.412s | 2.430s | +0.7% |

Iterations: 1 warmup + 20 measured

…ore, and asin/acos support

Replace flat dict + monkey-patched restore with LibdeviceSpec registry,
arity-dispatching factory, and single tagged stack in _LangPatchScope.
Unsupported ops now raise NotImplementedError immediately instead of
silently returning None. Extend symbolic engine with asin/acos ops and
add UnarySymbolicExpr.concretize().
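
A rough sketch of what a spec registry with an arity-dispatching factory could look like follows. The field names (`name`, `arity`, `impl`), the `make_libdevice_fn` helper, and the two-argument `atan2` entry are assumptions made for illustration; the commit message only confirms that a LibdeviceSpec registry exists and that unsupported ops raise NotImplementedError.

```python
import math
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class LibdeviceSpec:
    name: str
    arity: int
    impl: Callable[..., float]  # hypothetical field; math.* stands in for numpy


# Hypothetical registry contents; atan2 is included only to show arity 2.
REGISTRY = {
    spec.name: spec
    for spec in (
        LibdeviceSpec("tanh", 1, math.tanh),
        LibdeviceSpec("asin", 1, math.asin),
        LibdeviceSpec("acos", 1, math.acos),
        LibdeviceSpec("atan2", 2, math.atan2),
    )
}


def make_libdevice_fn(name):
    """Build a replacement for a libdevice stub, dispatching on arity."""
    spec = REGISTRY.get(name)
    if spec is None:
        # Fail loudly at call time instead of silently returning None.
        def unsupported(*args, **kwargs):
            raise NotImplementedError(
                f"libdevice.{name} is not supported in interpreter mode"
            )
        return unsupported

    if spec.arity == 1:
        def unary(x):
            return spec.impl(x)
        return unary

    def nary(*args):
        if len(args) != spec.arity:
            raise TypeError(f"{name} expects {spec.arity} arguments")
        return spec.impl(*args)
    return nary
```

With this shape, adding an op is a one-line registry change, and an op missing from the registry fails at the call site with a message naming it.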
…c for builder ops, add rsqrt

Address PR review feedback: _LIBDEVICE_REGISTRY is now the true single source
of truth. UNARY_OPS, _NUMPY_OPS, and _UNARY_NUMPY_TO_SYM_OP in
symbolic_engine.py are derived from the registry instead of being maintained
as parallel hardcoded mappings.

LibdeviceSpec gains an optional builder_method field so ops like rsqrt that
use interpreter_builder methods directly (rather than numpy ufuncs) can be
expressed in the registry. rsqrt is added as the first builder-backed op.

Tests: restore verification for libdevice patching, rsqrt E2E (sanitizer +
numerical correctness), updated consistency and concretize tests.
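
The optional builder_method field might be resolved along these lines. This is a sketch under assumptions: `FakeBuilder` is a hypothetical stand-in for interpreter_builder, `numpy_fn` is an assumed field name, and `math` functions replace numpy ufuncs so the example has no third-party dependencies.

```python
import math
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(frozen=True)
class LibdeviceSpec:
    name: str
    arity: int
    numpy_fn: Optional[Callable] = None   # hypothetical field name
    builder_method: Optional[str] = None  # name of a method on the builder


class FakeBuilder:
    """Hypothetical stand-in for interpreter_builder."""

    def unary_op(self, arg, op):
        return op(arg)

    def create_rsqrt(self, arg):
        return 1.0 / math.sqrt(arg)


REGISTRY = {
    "tanh": LibdeviceSpec("tanh", 1, numpy_fn=math.tanh),
    "rsqrt": LibdeviceSpec("rsqrt", 1, builder_method="create_rsqrt"),
}

# Derived from the registry rather than maintained as a parallel table.
UNARY_OPS = tuple(REGISTRY)


def call_op(builder, name, x):
    """Dispatch to a builder method if the spec names one, else to unary_op."""
    spec = REGISTRY[name]
    if spec.builder_method is not None:
        return getattr(builder, spec.builder_method)(x)
    return builder.unary_op(x, spec.numpy_fn)
```

Because `UNARY_OPS` is computed from `REGISTRY`, registering a new op updates every derived table at once.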

Filter all registry-derived structures by spec.arity == 1 to prevent
future arity>1 specs from being misclassified as unary ops. Replace
NotImplementedError in _to_z3_impl() with a fresh opaque Z3 Int symbol
so transcendental ops don't crash the sanitizer when their results flow
into Z3-analyzed paths.

Add tests for Z3 fallback, arity-conditional consistency, alias restore,
and module-style unsupported op (libdevice.erf).
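
As a small illustration of the arity filter and the opaque fallback, consider the sketch below. The `Spec` class, the `pow` entry, and `opaque_symbol_name` are hypothetical; in particular the name generator only mimics the idea of handing the solver a fresh symbol and does not use Z3.

```python
import itertools
from dataclasses import dataclass


@dataclass(frozen=True)
class Spec:
    name: str
    arity: int


# Hypothetical registry; "pow" stands for a future two-argument op.
REGISTRY = (
    Spec("tanh", 1),
    Spec("asin", 1),
    Spec("rsqrt", 1),
    Spec("pow", 2),
)

# Only arity-1 specs may appear in the unary tables.
UNARY_OPS = tuple(spec.name for spec in REGISTRY if spec.arity == 1)

_fresh_ids = itertools.count()


def opaque_symbol_name(op):
    """Return a unique name for an op that has no solver encoding.

    Stand-in for creating a fresh opaque solver symbol instead of raising.
    """
    return f"{op}_opaque_{next(_fresh_ids)}"
```

Each call to `opaque_symbol_name` yields a distinct name, so two uses of the same op are never assumed equal by accident.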
… caches

- Extract LibdeviceSpec and _LIBDEVICE_REGISTRY to core/libdevice_registry.py
  to break the reverse dependency (symbolic_engine.py → patch.py).
- Wrap triton_patch_lang + _patch_libdevice in try/except so a failure in
  _patch_libdevice rolls back all Triton lang state via scope.restore().
- Add _Z3_RANGE_BOUNDS to UnarySymbolicExpr: tanh/sin/cos bounded to [-1,1],
  exp/sqrt/rsqrt bounded to >=0, etc. Prevents unconstrained opaque symbols
  from degrading sanitizer analysis when transcendental ops flow into pointer
  arithmetic.
- Pre-build _INTERPRETER_FNS and _REGISTERED_SPECS at module level, cache
  unsupported-fn wrappers with @cache, skip _patch_libdevice entirely when
  the kernel doesn't reference libdevice.
- Add tests: patch rollback regression, Z3 range constraint solver checks,
  addptr-through-bounded-unary, multiple-alias restore, unsupported alias.
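
The rollback behaviour described in the second bullet can be sketched as follows. `PatchScope` and `patch_all` are hypothetical names, plain dicts stand in for the Triton namespaces, and the `fail_in_libdevice` flag exists only to simulate a failure in the second patching step.

```python
class PatchScope:
    """Records every patched attribute on one stack so it can be undone."""

    def __init__(self):
        self._saved = []  # (namespace, name, original) in patch order

    def set(self, namespace, name, value):
        self._saved.append((namespace, name, namespace[name]))
        namespace[name] = value

    def restore(self):
        # Undo in reverse order so nested patches unwind correctly.
        while self._saved:
            namespace, name, original = self._saved.pop()
            namespace[name] = original


def patch_all(lang_ns, libdevice_ns, fail_in_libdevice=False):
    """Patch both namespaces; roll everything back if any step fails."""
    scope = PatchScope()
    try:
        scope.set(lang_ns, "load", "interpreted load")
        if fail_in_libdevice:
            raise RuntimeError("simulated failure while patching libdevice")
        scope.set(libdevice_ns, "tanh", "numpy tanh")
    except Exception:
        scope.restore()
        raise
    return scope
```

If the libdevice step raises, the language-level patch applied just before it is undone as well, so the caller never sees a half-patched state.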
…ests

- Add erf to registry as numpy-backed op (dtype-preserving vectorized
  math.erf), since interpreter_builder.create_erf exists upstream.
- Remove blanket replacement of unregistered libdevice stubs with
  NotImplementedError — only patch registered ops. Unregistered stubs
  are left as-is, narrowing the patch surface to what we actually
  support.
- Remove _make_unsupported_libdevice_fn and the else branch in
  _patch_libdevice_aliases that replaced unregistered aliases.
- Add erf to e2e sanitizer op test, numerical correctness test, and
  unit concretize test.
- Add tests: unregistered alias left unchanged, same-module helper alias
  patched, default-arg capture NOT patched (documented limitation).
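
The default-arg limitation mentioned in the last bullet follows from how Python binds names. In the sketch below, `tanh_stub`, `my_tanh`, and `patch_aliases` are hypothetical, and `math.tanh` stands in for the patched implementation.

```python
import math


def tanh_stub(x):
    """Stand-in for an unpatched libdevice stub; returns None."""


my_tanh = tanh_stub  # module-level helper alias


def kernel_via_alias(x):
    # `my_tanh` is resolved from globals on every call, so it can be patched.
    return my_tanh(x)


def kernel_via_default(x, fn=tanh_stub):
    # `fn` was bound once, when the function was defined.
    return fn(x)


def patch_aliases(fn, stub, replacement):
    """Rebind every global of `fn` that points at `stub`; return originals."""
    saved = {}
    for name, value in list(fn.__globals__.items()):
        if value is stub:
            saved[name] = value
            fn.__globals__[name] = replacement
    return saved
```

After `patch_aliases(kernel_via_alias, stub, math.tanh)`, `kernel_via_alias` returns real values, while `kernel_via_default` still calls the original stub held in its `__defaults__`. Passing the saved dict to `kernel_via_alias.__globals__.update(...)` restores the aliases.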