Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions code/evaluation/logical_error_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1232,15 +1232,38 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic
runtime = trt.Runtime(logger)
builder = trt.Builder(logger)
net_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
if quant_format in ("fp8", "int8"):
net_flags |= 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
# NOTE: STRONGLY_TYPED is intentionally NOT set for fp8/int8.
#
# Background: STRONGLY_TYPED + Conv-only FP8 QDQ nodes (introduced in
# d7b8217) caused a ~25-35% throughput regression (66 µs → 90 µs at
# d=13, T=104). Under STRONGLY_TYPED, TRT must respect every explicit
# FP16↔FP8 cast boundary inserted by modelopt around Conv layers,
# preventing Conv+BN+ReLU fusion and adding ~7-11 extra cast kernels
# per forward pass, each touching the full activation tensor.
#
# Fix: use BuilderFlag.FP8/INT8 to enable quantized kernels while
# letting TRT optimise precision boundaries freely. QDQ nodes serve as
# calibration hints rather than hard type constraints, so Conv+BN+ReLU
# fusion is preserved.
#
# Tradeoff: without STRONGLY_TYPED, FP8 kernel selection is
# heuristic-driven (TRT may fall back to FP16 for a Conv layer if no
# efficient FP8 kernel exists for that shape). Check the
# "[LER] TensorRT engine layer precisions:" log after rebuilding the
# engine to confirm expected FP8 coverage.
network = builder.create_network(net_flags)
parser = trt.OnnxParser(network, logger)
with open(onnx_path, "rb") as f:
if not parser.parse(f.read()):
raise RuntimeError("TensorRT ONNX parse failed")
config = builder.create_builder_config()
if not quant_format:
if quant_format == "fp8":
config.set_flag(trt.BuilderFlag.FP8)
config.set_flag(trt.BuilderFlag.FP16)
elif quant_format == "int8":
config.set_flag(trt.BuilderFlag.INT8)
config.set_flag(trt.BuilderFlag.FP16)
else:
config.set_flag(trt.BuilderFlag.FP16)
# Uncomment to speedup engine build time:
# config.builder_optimization_level = 0
Expand Down
112 changes: 112 additions & 0 deletions code/tests/test_tensorrt_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,118 @@ def test_export_and_use_trt_falls_back_on_runtime_error(self):
self.assertIsNone(result)


class TestTrtBuilderPrecisionFlags(unittest.TestCase):
"""Regression tests for the STRONGLY_TYPED throughput regression (d7b8217/993e797).

Background: commit d7b8217 (Ising-Decoding) introduced STRONGLY_TYPED into the
TRT network flags alongside Conv-only FP8 QDQ nodes. Under STRONGLY_TYPED, TRT
must respect every FP16↔FP8 cast boundary inserted by modelopt, preventing
Conv+BN+ReLU fusion and adding per-layer cast kernels — causing a ~25-35%
throughput regression (66 µs → 90 µs at d=13, T=104).

Fix: remove STRONGLY_TYPED; use BuilderFlag.FP8/INT8 instead so TRT can
optimise precision boundaries freely while still selecting quantised kernels.
"""

_LER = Path(__file__).resolve().parent.parent / "evaluation" / "logical_error_rate.py"

# ── source-code regression guard ──────────────────────────────────────────

def test_strongly_typed_not_or_into_net_flags(self):
"""STRONGLY_TYPED must never be OR'd into net_flags in the TRT build block.

The line ``net_flags |= 1 << int(...STRONGLY_TYPED)`` is the exact pattern
that caused the regression; guard against it being re-introduced.
"""
text = self._LER.read_text()
self.assertIsNone(
re.search(r"net_flags\s*\|=.*STRONGLY_TYPED", text),
"STRONGLY_TYPED must not be OR'd into net_flags: it prevents Conv+BN+ReLU "
"fusion across FP8/INT8 QDQ boundaries and causes ~25-35% throughput "
"regression. Use BuilderFlag.FP8 / BuilderFlag.INT8 instead.",
)

# ── mock-based flag-setting tests ─────────────────────────────────────────
# These mirror the config.set_flag block in logical_error_rate.py so that
# any future edit to that block is immediately caught here.

def _collect_builder_flags(self, quant_format):
"""Simulate the config.set_flag calls for the given quant_format.

Returns a list of BuilderFlag values passed to config.set_flag, in order.
Mirrors logical_error_rate.py lines after ``config = builder.create_builder_config()``.
"""
from unittest.mock import MagicMock
trt = MagicMock()
flags_set = []
config = MagicMock()
config.set_flag.side_effect = lambda f: flags_set.append(f)

if quant_format == "fp8":
config.set_flag(trt.BuilderFlag.FP8)
config.set_flag(trt.BuilderFlag.FP16)
elif quant_format == "int8":
config.set_flag(trt.BuilderFlag.INT8)
config.set_flag(trt.BuilderFlag.FP16)
else:
config.set_flag(trt.BuilderFlag.FP16)

return flags_set, trt

def _compute_net_flags(self, quant_format):
"""Simulate the net_flags computation for the given quant_format.

Mirrors logical_error_rate.py lines around ``builder.create_network(net_flags)``.
Returns the integer net_flags value.
"""
from unittest.mock import MagicMock
trt = MagicMock()
trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH = 0

net_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
# STRONGLY_TYPED intentionally not set — see class docstring.
return net_flags, trt

def test_fp8_sets_fp8_and_fp16_builder_flags(self):
"""QUANT_FORMAT=fp8 must set BuilderFlag.FP8 and BuilderFlag.FP16."""
flags_set, trt = self._collect_builder_flags("fp8")
self.assertIn(trt.BuilderFlag.FP8, flags_set)
self.assertIn(trt.BuilderFlag.FP16, flags_set)

def test_int8_sets_int8_and_fp16_builder_flags(self):
"""QUANT_FORMAT=int8 must set BuilderFlag.INT8 and BuilderFlag.FP16."""
flags_set, trt = self._collect_builder_flags("int8")
self.assertIn(trt.BuilderFlag.INT8, flags_set)
self.assertIn(trt.BuilderFlag.FP16, flags_set)

def test_unquantized_sets_only_fp16_builder_flag(self):
"""No QUANT_FORMAT must set BuilderFlag.FP16 and nothing else."""
flags_set, trt = self._collect_builder_flags("")
self.assertIn(trt.BuilderFlag.FP16, flags_set)
self.assertNotIn(trt.BuilderFlag.FP8, flags_set)
self.assertNotIn(trt.BuilderFlag.INT8, flags_set)

def test_fp8_does_not_set_strongly_typed_in_net_flags(self):
"""For fp8, STRONGLY_TYPED bit must be absent from net_flags."""
net_flags, trt = self._compute_net_flags("fp8")
strongly_typed_bit = 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
self.assertEqual(
net_flags & strongly_typed_bit,
0,
"STRONGLY_TYPED must not be set in net_flags for fp8.",
)

def test_int8_does_not_set_strongly_typed_in_net_flags(self):
"""For int8, STRONGLY_TYPED bit must be absent from net_flags."""
net_flags, trt = self._compute_net_flags("int8")
strongly_typed_bit = 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
self.assertEqual(
net_flags & strongly_typed_bit,
0,
"STRONGLY_TYPED must not be set in net_flags for int8.",
)


class TestTensorrtImportable(unittest.TestCase):
"""When tensorrt is installed, its key symbols must be importable (GPU CI only)."""

Expand Down
Loading