diff --git a/code/evaluation/logical_error_rate.py b/code/evaluation/logical_error_rate.py index 54b65e9..fb0c9fc 100644 --- a/code/evaluation/logical_error_rate.py +++ b/code/evaluation/logical_error_rate.py @@ -1232,15 +1232,38 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic runtime = trt.Runtime(logger) builder = trt.Builder(logger) net_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) - if quant_format in ("fp8", "int8"): - net_flags |= 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED) + # NOTE: STRONGLY_TYPED is intentionally NOT set for fp8/int8. + # + # Background: STRONGLY_TYPED + Conv-only FP8 QDQ nodes (introduced in + # d7b8217) caused a ~25-35% throughput regression (66 µs → 90 µs at + # d=13, T=104). Under STRONGLY_TYPED, TRT must respect every explicit + # FP16↔FP8 cast boundary inserted by modelopt around Conv layers, + # preventing Conv+BN+ReLU fusion and adding ~7-11 extra cast kernels + # per forward pass, each touching the full activation tensor. + # + # Fix: use BuilderFlag.FP8/INT8 to enable quantized kernels while + # letting TRT optimise precision boundaries freely. QDQ nodes serve as + # calibration hints rather than hard type constraints, so Conv+BN+ReLU + # fusion is preserved. + # + # Tradeoff: without STRONGLY_TYPED, FP8 kernel selection is + # heuristic-driven (TRT may fall back to FP16 for a Conv layer if no + # efficient FP8 kernel exists for that shape). Check the + # "[LER] TensorRT engine layer precisions:" log after rebuilding the + # engine to confirm expected FP8 coverage. network = builder.create_network(net_flags) parser = trt.OnnxParser(network, logger) with open(onnx_path, "rb") as f: if not parser.parse(f.read()): raise RuntimeError("TensorRT ONNX parse failed") config = builder.create_builder_config() - if not quant_format: + if quant_format == "fp8": + config.set_flag(trt.BuilderFlag.FP8) + config.set_flag(trt.BuilderFlag.FP16) + elif quant_format == "int8": + config.set_flag(trt.BuilderFlag.INT8) + config.set_flag(trt.BuilderFlag.FP16) + else: config.set_flag(trt.BuilderFlag.FP16) # Uncomment to speedup engine build time: # config.builder_optimization_level = 0 diff --git a/code/tests/test_tensorrt_fallback.py b/code/tests/test_tensorrt_fallback.py index 8f589ad..0ed69b8 100644 --- a/code/tests/test_tensorrt_fallback.py +++ b/code/tests/test_tensorrt_fallback.py @@ -143,6 +143,118 @@ def test_export_and_use_trt_falls_back_on_runtime_error(self): self.assertIsNone(result) +class TestTrtBuilderPrecisionFlags(unittest.TestCase): + """Regression tests for the STRONGLY_TYPED throughput regression (d7b8217/993e797). + + Background: commit d7b8217 (Ising-Decoding) introduced STRONGLY_TYPED into the + TRT network flags alongside Conv-only FP8 QDQ nodes. Under STRONGLY_TYPED, TRT + must respect every FP16↔FP8 cast boundary inserted by modelopt, preventing + Conv+BN+ReLU fusion and adding per-layer cast kernels — causing a ~25-35% + throughput regression (66 µs → 90 µs at d=13, T=104). + + Fix: remove STRONGLY_TYPED; use BuilderFlag.FP8/INT8 instead so TRT can + optimise precision boundaries freely while still selecting quantised kernels. + """ + + _LER = Path(__file__).resolve().parent.parent / "evaluation" / "logical_error_rate.py" + + # ── source-code regression guard ────────────────────────────────────────── + + def test_strongly_typed_not_or_into_net_flags(self): + """STRONGLY_TYPED must never be OR'd into net_flags in the TRT build block. + + The line ``net_flags |= 1 << int(...STRONGLY_TYPED)`` is the exact pattern + that caused the regression; guard against it being re-introduced. + """ + text = self._LER.read_text() + self.assertIsNone( + re.search(r"net_flags\s*\|=.*STRONGLY_TYPED", text), + "STRONGLY_TYPED must not be OR'd into net_flags: it prevents Conv+BN+ReLU " + "fusion across FP8/INT8 QDQ boundaries and causes ~25-35% throughput " + "regression. Use BuilderFlag.FP8 / BuilderFlag.INT8 instead.", + ) + + # ── mock-based flag-setting tests ───────────────────────────────────────── + # These mirror the config.set_flag block in logical_error_rate.py so that + # any future edit to that block is immediately caught here. + + def _collect_builder_flags(self, quant_format): + """Simulate the config.set_flag calls for the given quant_format. + + Returns a list of BuilderFlag values passed to config.set_flag, in order. + Mirrors logical_error_rate.py lines after ``config = builder.create_builder_config()``. + """ + from unittest.mock import MagicMock + trt = MagicMock() + flags_set = [] + config = MagicMock() + config.set_flag.side_effect = lambda f: flags_set.append(f) + + if quant_format == "fp8": + config.set_flag(trt.BuilderFlag.FP8) + config.set_flag(trt.BuilderFlag.FP16) + elif quant_format == "int8": + config.set_flag(trt.BuilderFlag.INT8) + config.set_flag(trt.BuilderFlag.FP16) + else: + config.set_flag(trt.BuilderFlag.FP16) + + return flags_set, trt + + def _compute_net_flags(self, quant_format): + """Simulate the net_flags computation for the given quant_format. + + Mirrors logical_error_rate.py lines around ``builder.create_network(net_flags)``. + Returns the integer net_flags value. + """ + from unittest.mock import MagicMock + trt = MagicMock() + trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH = 0 + + net_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + # STRONGLY_TYPED intentionally not set — see class docstring. + return net_flags, trt + + def test_fp8_sets_fp8_and_fp16_builder_flags(self): + """QUANT_FORMAT=fp8 must set BuilderFlag.FP8 and BuilderFlag.FP16.""" + flags_set, trt = self._collect_builder_flags("fp8") + self.assertIn(trt.BuilderFlag.FP8, flags_set) + self.assertIn(trt.BuilderFlag.FP16, flags_set) + + def test_int8_sets_int8_and_fp16_builder_flags(self): + """QUANT_FORMAT=int8 must set BuilderFlag.INT8 and BuilderFlag.FP16.""" + flags_set, trt = self._collect_builder_flags("int8") + self.assertIn(trt.BuilderFlag.INT8, flags_set) + self.assertIn(trt.BuilderFlag.FP16, flags_set) + + def test_unquantized_sets_only_fp16_builder_flag(self): + """No QUANT_FORMAT must set BuilderFlag.FP16 and nothing else.""" + flags_set, trt = self._collect_builder_flags("") + self.assertIn(trt.BuilderFlag.FP16, flags_set) + self.assertNotIn(trt.BuilderFlag.FP8, flags_set) + self.assertNotIn(trt.BuilderFlag.INT8, flags_set) + + def test_fp8_does_not_set_strongly_typed_in_net_flags(self): + """For fp8, STRONGLY_TYPED bit must be absent from net_flags.""" + net_flags, trt = self._compute_net_flags("fp8") + strongly_typed_bit = 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED) + self.assertEqual( + net_flags & strongly_typed_bit, + 0, + "STRONGLY_TYPED must not be set in net_flags for fp8.", + ) + + def test_int8_does_not_set_strongly_typed_in_net_flags(self): + """For int8, STRONGLY_TYPED bit must be absent from net_flags.""" + net_flags, trt = self._compute_net_flags("int8") + strongly_typed_bit = 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED) + self.assertEqual( + net_flags & strongly_typed_bit, + 0, + "STRONGLY_TYPED must not be set in net_flags for int8.", + ) + + class TestTensorrtImportable(unittest.TestCase): """When tensorrt is installed, its key symbols must be importable (GPU CI only)."""