Skip to content

Commit 993e797

Browse files
ivanbasovIgorBarattaclaude
authored
Add ONNX export, INT8/FP8 quantization, and SafeTensors inference loading (#14)
* feat(onnx): add QUANT_FORMAT int8/fp8 PTQ via modelopt.onnx - Add _collect_calibration_dets module-level helper that samples detector inputs from the inference dataloader for ONNX calibration - Parse QUANT_FORMAT env var (int8, fp8) in OnnxWorkflow export path; invalid values are ignored with a warning - Two-step export: always write FP32 ONNX first, then optionally apply modelopt.onnx.quantization.quantize() for the requested format - fp8 is fail-fast on error; int8 silently falls back to FP32 ONNX - Add QUANT_CALIB_SAMPLES env var (default 256) to control calibration sample count - Add test_onnx_quant_workflow.py: 13 CPU-only unit tests covering the calibration helper and QUANT_FORMAT routing logic * fix(onnx): re-derive engine_path from final onnx_path after quant fallback * review: fix run.py, temp file cleanup, YAPF, README ONNX section - run.py: remove emoji from print statements (style inconsistency) - run.py: remove no-op torch.compile(disable=True) calls - run.py: extract _resolve_dir() helper to replace 4 copies of the current_file/project_root path resolution pattern - run.py: replace bare torch.load/load_state_dict with _load_state_dict_from_pt() which handles model_state_dict/state_dict/bare-dict formats and strips the DDP "module." prefix — consistent with checkpoint_to_safetensors.py - tests: add addCleanup(os.unlink) for all NamedTemporaryFile paths - YAPF: reformat logical_error_rate.py and test_onnx_quant_workflow.py - README: add ONNX export and quantization section documenting ONNX_WORKFLOW modes, QUANT_FORMAT, QUANT_CALIB_SAMPLES Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * quantize only CNN layers * fix(ci): YAPF, move nvidia-modelopt to train reqs, add prerequisite tests - YAPF: reformat 3 long lines in logical_error_rate.py introduced by the "quantize only CNN layers" commit (d7b8217) - Move nvidia-modelopt[onnx] from requirements_public_inference.txt to requirements_public_train.txt; it is only needed for ONNX PTQ export (QUANT_FORMAT env var), not for pure inference, and has no Python 3.13 build — keeping it in inference reqs broke unit-tests/py3.13 in CI - Add python_version<"3.13" marker so the CI train matrix installs it on supported Python versions without failing on 3.13 - Add TestModeloptPrerequisite in test_onnx_quant_workflow.py: - asserts nvidia-modelopt is declared in requirements_public_train.txt - asserts it is absent from requirements_public_inference.txt - conditionally checks the import is resolvable when the package is present (skipped on Python 3.13+ and when not installed) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * feat(onnx): add onnxruntime INT8 fallback for Python 3.13+ nvidia-modelopt does not support Python 3.13+. Add a conditional backend dispatch so QUANT_FORMAT=int8 works on all supported Python versions: - Add _ort_quantize_int8() module-level helper that uses onnxruntime.quantization.quantize_static() with QDQ/QInt8 format and a CalibrationDataReader wrapping the pre-collected calib_dets array - In the quantization block, branch on sys.version_info >= (3, 13): - Python 3.13+: call _ort_quantize_int8(); raise immediately for FP8 (no viable 3.13-compatible FP8 PTQ library available) - Python <3.13: keep existing modelopt path unchanged - Add onnxruntime (python_version >= "3.13") to requirements_public_train.txt - Expand TestOrtQuantizeInt8 tests: - round-trip test (build tiny Gemm ONNX, quantize, validate) on 3.13+ - dispatch test verifying _ort_quantize_int8 is called on 3.13+ - FP8-on-3.13 raises RuntimeError - Expand TestModeloptPrerequisite: assert onnxruntime appears in train requirements and both quant packages are absent from inference requirements Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix(prereqs): document tensorrt as optional GPU dep, add fallback tests tensorrt is a heavy CUDA-only SDK (~500 MB) that cannot be pip-installed in CPU-only CI, so it is not added as an active pip requirement. Instead: - Add a comment block in requirements_public_inference.txt documenting tensorrt as an optional prerequisite for ONNX_WORKFLOW=2/3 paths, with the install command and a note about graceful fallback - Add test_tensorrt_fallback.py with three test classes: - TestTensorrtDocumented: asserts the requirements comment exists and tensorrt is NOT an active pip requirement - TestTensorrtFallback: verifies both TRT import sites (USE_ENGINE_ONLY and EXPORT_AND_USE_TRT) set trt_context=None on ImportError and do not propagate the exception to the caller - TestTensorrtImportable: checks key TRT symbols (Logger, Runtime, Builder, BuilderFlag, LayerInformationFormat) when tensorrt is installed; skipped silently on CPU-only environments Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix(onnx): use import-based dispatch for modelopt/ort; install modelopt on py3.13 nvidia-modelopt works on Python 3.13 when installed with --ignore-requires-python (confirmed by modelopt maintainers). - logical_error_rate.py: replace sys.version_info dispatch with an ImportError-based dispatch — try modelopt first (INT8+FP8), fall back to _ort_quantize_int8 only when modelopt is not importable; FP8 raises RuntimeError with the --ignore-requires-python install hint - check_python_compat.sh: after the main requirements install, re-install nvidia-modelopt[onnx] with --ignore-requires-python when MODE=train and Python >= 3.13, so GPU CI on 3.13 uses the full modelopt path - requirements_public_train.txt: add comment documenting the 3.13 install approach for manual setups - test_onnx_quant_workflow.py: - remove py3.13-specific skip from test_ort_quantize_int8_produces_output_file (now skips when onnxruntime is not installed, regardless of version) - replace test_ort_quantize_int8_dispatch_on_py313 with test_ort_quantize_int8_called_on_modelopt_import_error - replace test_fp8_raises_on_py313 with test_fp8_raises_on_modelopt_import_error - remove py3.13 version guard from test_modelopt_importable_when_installed - remove py3.13 version guard from test_ort_importable_when_installed Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix(test): pin ONNX IR version 8 in ort quantize test modelopt[onnx] pulls in onnxruntime-gpu~=1.22.0 as a dependency on all Python versions. Newer ONNX packages (1.19+) default model.ir_version to 12, but onnxruntime-gpu 1.22.0 only supports up to IR version 10, causing test_ort_quantize_int8_produces_output_file to fail on the GPU CI for py3.11, py3.12, and py3.13. Pin model.ir_version = 8 (the minimum required for opset 17) before saving the test model so the calibration InferenceSession succeeds with any onnxruntime version that supports IR ≤ 10. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * test(onnx): add end-to-end mq.quantize() tests for modelopt Previous coverage only verified that modelopt.onnx.quantization was importable. Add TestModeloptQuantize with two tests that actually call mq.quantize() on a real ONNX model: - test_mq_quantize_int8_produces_valid_onnx: verifies the output file is created and passes onnx.checker (confirms modelopt works at runtime, not just at import time — this is the key Python 3.13 regression check) - test_mq_quantize_int8_output_differs_from_fp32: verifies QDQ nodes were inserted (output graph has more nodes than the FP32 source) Both tests share a _build_tiny_model() helper that creates a minimal Gemm ONNX model with input "dets" and 16 calibration rows, matching the production calibration_data={"dets": calib_dets} call convention. model.ir_version is pinned to 8 for onnxruntime-gpu 1.22.0 compatibility. Tests are skipped when nvidia-modelopt is not installed. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix(test): use float32 calibration data in TestModeloptQuantize mq.quantize() runs an internal ONNX inference session to profile MatMul nodes; feeding uint8 calibration data to a float-input model caused InvalidArgument. Switch to np.random.randn(...).astype(float32). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix(trt): raise RuntimeError when tensorrt missing for ONNX_WORKFLOW=2/3 Previously both TRT import sites caught ImportError inside a broad `except Exception` block and silently fell back to PyTorch with a print. This masked misconfiguration: the user explicitly selected ONNX_WORKFLOW=2 or 3, so a missing tensorrt install is always a hard error. Changes: - USE_ENGINE_ONLY (workflow=3): ImportError now raises RuntimeError with install hint; other TRT errors (bad engine file) still fall back gracefully. - EXPORT_AND_USE_TRT (workflow=2): same split. - test_tensorrt_fallback.py: replace the old "falls back on ImportError" tests with "raises RuntimeError on ImportError" tests; add chained cause check and non-import fallback tests. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix(test): fix ORT calibration for ort quantize test ORT's MinMaxCalibrater augments the model to expose intermediate tensors for calibration, but graph *inputs* are not included in the augmented outputs. When the test model had dets->Gemm directly, ORT never collected calibration stats for 'dets', causing: ValueError: Quantization parameters are not specified for param dets. Fix: insert a Relu node (dets -> Relu -> dets_relu -> Gemm) so the Gemm input is an intermediate tensor that gets calibrated. Also switch the calibration array to float32 (consistent with model dtype) and add rewind() to _DetCalibReader in production code for calibration methods that make multiple passes. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix(test): skip ort quantize output test when modelopt is installed _ort_quantize_int8 is only invoked when modelopt is absent. When modelopt IS installed its mq.quantize() call leaves ORT's execution- provider state dirty (failed TRT EP init), causing the calibration InferenceSession to run silently without producing stats, which makes quantize_static raise: ValueError: Quantization parameters are not specified for param dets. The test is meaningless in that environment anyway — if modelopt is present the ort path is never taken. Skip when modelopt is importable. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * review: address PR #14 review comments - README: ONNX_WORKFLOW=1 runs PyTorch inference after export (not stop) (bmhowe23 suggestion) - LER: cast calib_dets to float32 before passing to mq.quantize(); _collect_calibration_dets returns uint8 but modelopt expects float (sacpis: bug report on line 1077) - LER: use Path.with_suffix('.engine') instead of str.replace (sacpis nit on line 1104) - LER: add pathlib.Path import - test: remove spurious @skipUnless from _build_tiny_model helper; it is not a test method and the decorator has no effect (sacpis nit on line 299) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * refactor: extract _parse_quant_format() helper from LER Move the QUANT_FORMAT env-var read/validate/warn block into a module-level helper so the test can call the real production logic instead of re-implementing it. - Add _parse_quant_format(rank=0) -> str in logical_error_rate.py - Replace inline parsing block in run_inference_and_decode with a single _parse_quant_format(rank=dist.rank) call - Import _parse_quant_format in test_onnx_quant_workflow.py and simplify _run_quant_block to delegate to it Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix: guard against num_obs < 1 in _collect_calibration_dets Python's [:, :-0] is equivalent to [:, :0] and silently returns an empty tensor rather than the full row. Add an explicit check so the caller gets a clear ValueError instead of a confusing width-mismatch error downstream. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * docs: fix tensorrt comment — missing TRT now raises RuntimeError The comment said "Absent at runtime causes graceful fallback to the PyTorch path", but since the TRT ImportError fix (ae0f3b1) both ONNX_WORKFLOW=2 and =3 raise RuntimeError instead of falling back. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> --------- Co-authored-by: Igor Baratta <ialmeidabara@nvidia.com> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 80a64db commit 993e797

9 files changed

Lines changed: 951 additions & 82 deletions

README.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,48 @@ WORKFLOW=inference bash code/scripts/local_run.sh
131131
```
132132

133133
`MODEL_ID` is the public model identifier (1–5); see `model/registry.py` for the mapping.
134+
The pre-trained public models use `--model-id 1` (R=9) and `--model-id 4` (R=13).
135+
136+
### ONNX export and quantization (optional, post-training)
137+
138+
After training (or starting from the shipped `.safetensors` files), you can export the model to
139+
ONNX and optionally apply INT8 or FP8 post-training quantization for deployment.
140+
141+
Set the `ONNX_WORKFLOW` and (optionally) `QUANT_FORMAT` environment variables before running
142+
inference with `local_run.sh`:
143+
144+
| `ONNX_WORKFLOW` | Behavior |
145+
|---|---|
146+
| `0` (default) | PyTorch inference only, no ONNX export |
147+
| `1` | Export ONNX model and run inference with PyTorch |
148+
| `2` | Export ONNX model and run inference via TensorRT |
149+
| `3` | Load a pre-existing TensorRT engine file and run inference |
150+
151+
```bash
152+
# Export ONNX only (no TensorRT)
153+
ONNX_WORKFLOW=1 WORKFLOW=inference bash code/scripts/local_run.sh
154+
155+
# Export ONNX + apply INT8 quantization + run TensorRT inference
156+
ONNX_WORKFLOW=2 QUANT_FORMAT=int8 WORKFLOW=inference bash code/scripts/local_run.sh
157+
158+
# Export ONNX + apply FP8 quantization + run TensorRT inference
159+
ONNX_WORKFLOW=2 QUANT_FORMAT=fp8 WORKFLOW=inference bash code/scripts/local_run.sh
160+
161+
# Use a pre-built TensorRT engine (skip export)
162+
ONNX_WORKFLOW=3 WORKFLOW=inference bash code/scripts/local_run.sh
163+
```
164+
165+
**Quantization variables:**
166+
167+
| Variable | Default | Description |
168+
|---|---|---|
169+
| `QUANT_FORMAT` | unset | `int8` or `fp8`. Unset means no quantization (FP32 ONNX). |
170+
| `QUANT_CALIB_SAMPLES` | `256` | Calibration samples for INT8/FP8 post-training quantization. |
171+
172+
Notes:
173+
- TensorRT workflows (`ONNX_WORKFLOW=2` or `3`) require `tensorrt` and `modelopt`.
174+
- FP8 quantization failure is fatal. INT8 failure falls back to the FP32 ONNX model silently.
175+
- ONNX and engine files are written to the current working directory.
134176

135177
### GPU selection
136178

code/evaluation/logical_error_rate.py

Lines changed: 214 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import sys
1616
import os
1717
from enum import IntEnum
18+
from pathlib import Path
1819
from typing import Optional
1920

2021

@@ -50,6 +51,110 @@ def _detect_shm_bytes() -> Optional[int]:
5051
return None
5152

5253

54+
def _parse_quant_format(rank: int = 0) -> str:
55+
"""Read and validate the QUANT_FORMAT environment variable.
56+
57+
Returns the validated format string ('int8' or 'fp8'), or '' if unset or invalid.
58+
Prints a warning on rank 0 when the value is set but not recognised.
59+
"""
60+
quant_format = os.environ.get("QUANT_FORMAT", "").strip().lower()
61+
if quant_format and quant_format not in ("int8", "fp8"):
62+
if rank == 0:
63+
print(f"[LER] Invalid QUANT_FORMAT='{quant_format}', ignoring. Supported: int8, fp8")
64+
quant_format = ""
65+
return quant_format
66+
67+
68+
def _collect_calibration_dets(
69+
test_dataloader,
70+
num_obs: int,
71+
target_samples: int,
72+
expected_width: int,
73+
) -> "np.ndarray":
74+
"""Collect representative detector inputs from a dataloader for ONNX calibration.
75+
76+
Args:
77+
test_dataloader: DataLoader yielding batches with a "dets_and_obs" key.
78+
num_obs: Number of observable columns at the end of dets_and_obs to strip.
79+
target_samples: Desired number of calibration rows.
80+
expected_width: Expected number of detector columns after stripping observables.
81+
82+
Returns:
83+
np.ndarray of shape (target_samples, expected_width), dtype uint8.
84+
"""
85+
if num_obs < 1:
86+
raise ValueError(
87+
f"num_obs must be >= 1, got {num_obs}. "
88+
"dets_and_obs[:, :-0] would silently return an empty tensor."
89+
)
90+
target_samples = max(int(target_samples), 1)
91+
chunks = []
92+
collected = 0
93+
for calib_batch in test_dataloader:
94+
dets_and_obs_batch = calib_batch["dets_and_obs"]
95+
dets_only_batch = dets_and_obs_batch[:, :-num_obs].to(torch.uint8).contiguous()
96+
if int(dets_only_batch.shape[1]) != int(expected_width):
97+
raise RuntimeError(
98+
f"Calibration det width {dets_only_batch.shape[1]} != expected {expected_width}"
99+
)
100+
if dets_only_batch.numel() == 0:
101+
continue
102+
take = min(target_samples - collected, int(dets_only_batch.shape[0]))
103+
if take > 0:
104+
chunks.append(dets_only_batch[:take].cpu().numpy())
105+
collected += take
106+
if collected >= target_samples:
107+
break
108+
if not chunks:
109+
raise RuntimeError("No calibration samples could be collected from test_dataloader.")
110+
calib = np.concatenate(chunks, axis=0)
111+
if calib.shape[0] < target_samples:
112+
reps = int(np.ceil(target_samples / float(calib.shape[0])))
113+
calib = np.tile(calib, (reps, 1))[:target_samples]
114+
return np.ascontiguousarray(calib, dtype=np.uint8)
115+
116+
117+
def _ort_quantize_int8(fp32_onnx_path: str, output_path: str, calib_dets: "np.ndarray") -> None:
118+
"""INT8 static quantization via onnxruntime.quantization (Python 3.13+ fallback).
119+
120+
Used when nvidia-modelopt is unavailable (it does not support Python 3.13+).
121+
Quantises all Conv and Gemm nodes with QInt8 weights and activations using
122+
QDQ format, which is compatible with TensorRT INT8 parsing.
123+
124+
Args:
125+
fp32_onnx_path: Path to the source FP32 ONNX model.
126+
output_path: Destination path for the quantized ONNX model.
127+
calib_dets: Calibration data array of shape (N, det_cols), dtype uint8.
128+
"""
129+
from onnxruntime.quantization import (
130+
CalibrationDataReader,
131+
QuantFormat,
132+
QuantType,
133+
quantize_static,
134+
)
135+
136+
class _DetCalibReader(CalibrationDataReader):
137+
138+
def __init__(self, data):
139+
self._rows = [{"dets": data[i:i + 1].astype("float32")} for i in range(len(data))]
140+
self._iter = iter(self._rows)
141+
142+
def get_next(self):
143+
return next(self._iter, None)
144+
145+
def rewind(self):
146+
self._iter = iter(self._rows)
147+
148+
quantize_static(
149+
fp32_onnx_path,
150+
output_path,
151+
_DetCalibReader(calib_dets),
152+
quant_format=QuantFormat.QDQ,
153+
activation_type=QuantType.QInt8,
154+
weight_type=QuantType.QInt8,
155+
)
156+
157+
53158
def _time_single_shot_latency_stim(
54159
matcher,
55160
baseline_syndromes: np.ndarray,
@@ -878,8 +983,15 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic
878983
if dist.rank == 0:
879984
print(f"[LER] Invalid ONNX_WORKFLOW='{_workflow_raw}', using 0 (torch only).")
880985
trt_context = None # (context, engine, device_id) when using TensorRT
881-
onnx_path = os.path.join(os.getcwd(), f"predecoder_memory_d{D}_T{T_original}_{basis}.onnx")
882-
engine_path = os.path.join(os.getcwd(), f"predecoder_memory_d{D}_T{T_original}_{basis}.engine")
986+
# --- QUANT_FORMAT: optional quantization (int8, fp8) applied to ONNX after FP32 export ---
987+
quant_format = _parse_quant_format(rank=dist.rank)
988+
quant_suffix = f"_{quant_format}" if quant_format else ""
989+
onnx_path = os.path.join(
990+
os.getcwd(), f"predecoder_memory_d{D}_T{T_original}_{basis}{quant_suffix}.onnx"
991+
)
992+
engine_path = os.path.join(
993+
os.getcwd(), f"predecoder_memory_d{D}_T{T_original}_{basis}{quant_suffix}.engine"
994+
)
883995
half = (D * D - 1) // 2
884996
example_shape = (batch_size_original, 2 * T_original * half)
885997

@@ -901,6 +1013,11 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic
9011013
print(
9021014
f"[LER] TensorRT engine loaded from {engine_path} in {t_load_end - t_load_start:.3f}s"
9031015
)
1016+
except ImportError as e:
1017+
raise RuntimeError(
1018+
"[LER] ONNX_WORKFLOW=3 (USE_ENGINE_ONLY) requires tensorrt to be installed. "
1019+
"Install with: pip install tensorrt"
1020+
) from e
9041021
except Exception as e:
9051022
if dist.rank == 0:
9061023
print(f"[LER] TensorRT engine load failed: {e}; falling back to PyTorch.")
@@ -915,10 +1032,16 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic
9151032
if dist.rank == 0:
9161033
try:
9171034
example_dets = torch.randint(0, 2, example_shape, dtype=torch.uint8, device=device)
1035+
1036+
# Step 1: Always export FP32 ONNX first
1037+
fp32_onnx_path = (
1038+
onnx_path
1039+
if not quant_format else onnx_path.replace(f"_{quant_format}.onnx", ".onnx")
1040+
)
9181041
torch.onnx.export(
9191042
pipeline_module,
9201043
example_dets,
921-
onnx_path,
1044+
fp32_onnx_path,
9221045
opset_version=18,
9231046
external_data=False,
9241047
input_names=["dets"],
@@ -934,27 +1057,84 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic
9341057
do_constant_folding=True,
9351058
dynamo=False,
9361059
)
937-
print(f"[LER] Exported ONNX: {onnx_path}")
1060+
print(f"[LER] Exported FP32 ONNX: {fp32_onnx_path}")
1061+
1062+
# Step 2: If QUANT_FORMAT is set, apply ONNX-level quantization.
1063+
# Backend: nvidia-modelopt on Python <3.13; onnxruntime on Python 3.13+
1064+
# (nvidia-modelopt does not support Python 3.13+).
1065+
if quant_format:
1066+
try:
1067+
num_obs_for_calib = circuit.num_observables
1068+
calib_num_samples = int(os.environ.get("QUANT_CALIB_SAMPLES", "256"))
1069+
print(
1070+
f"[LER] Collecting {calib_num_samples} calibration samples "
1071+
"from inference dataloader..."
1072+
)
1073+
calib_dets = _collect_calibration_dets(
1074+
test_dataloader, num_obs_for_calib, calib_num_samples, example_shape[1]
1075+
)
1076+
1077+
print(
1078+
f"[LER] Applying {quant_format.upper()} quantization to ONNX model..."
1079+
)
1080+
# Prefer modelopt (INT8+FP8); fall back to onnxruntime (INT8 only)
1081+
# when modelopt is not installed. On Python 3.13+ modelopt can
1082+
# be installed with: pip install nvidia-modelopt[onnx]
1083+
# --ignore-requires-python
1084+
try:
1085+
import modelopt.onnx.quantization as mq
1086+
quant_kwargs = {}
1087+
if quant_format == "fp8":
1088+
quant_kwargs["op_types_to_quantize"] = ["Conv"]
1089+
quant_kwargs["high_precision_dtype"] = "fp16"
1090+
mq.quantize(
1091+
onnx_path=fp32_onnx_path,
1092+
quantize_mode=quant_format,
1093+
calibration_data={"dets": calib_dets.astype("float32")},
1094+
output_path=onnx_path,
1095+
**quant_kwargs,
1096+
)
1097+
except ImportError:
1098+
if quant_format == "fp8":
1099+
raise RuntimeError(
1100+
"[LER] FP8 quantization requires nvidia-modelopt. "
1101+
"Install with: pip install 'nvidia-modelopt[onnx]'"
1102+
" --ignore-requires-python"
1103+
)
1104+
_ort_quantize_int8(fp32_onnx_path, onnx_path, calib_dets)
1105+
print(f"[LER] Exported quantized ONNX: {onnx_path}")
1106+
except Exception as e:
1107+
if quant_format == "fp8":
1108+
raise RuntimeError(
1109+
f"[LER] FP8 ONNX quantization failed (fail-fast): {e}"
1110+
) from e
1111+
print(f"[LER] ONNX quantization failed: {e}; using FP32 ONNX.")
1112+
onnx_path = fp32_onnx_path
9381113
except Exception as e:
9391114
if dist.rank == 0:
9401115
print(f"[LER] ONNX export failed: {e}; falling back to PyTorch.")
9411116
onnx_workflow = OnnxWorkflow.TORCH_ONLY
9421117
if dist.world_size > 1:
9431118
torch.distributed.barrier()
1119+
# Re-derive engine_path from the final onnx_path (may have changed on quant fallback)
1120+
engine_path = str(Path(onnx_path).with_suffix(".engine"))
9441121
if onnx_workflow == OnnxWorkflow.EXPORT_AND_USE_TRT and device.type == "cuda":
9451122
try:
9461123
import tensorrt as trt
9471124
logger = trt.Logger(trt.Logger.WARNING)
9481125
runtime = trt.Runtime(logger)
9491126
builder = trt.Builder(logger)
950-
network = builder.create_network(
951-
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
952-
)
1127+
net_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
1128+
if quant_format in ("fp8", "int8"):
1129+
net_flags |= 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
1130+
network = builder.create_network(net_flags)
9531131
parser = trt.OnnxParser(network, logger)
9541132
with open(onnx_path, "rb") as f:
9551133
if not parser.parse(f.read()):
9561134
raise RuntimeError("TensorRT ONNX parse failed")
9571135
config = builder.create_builder_config()
1136+
if not quant_format:
1137+
config.set_flag(trt.BuilderFlag.FP16)
9581138
# Uncomment this out to speedup engine build time.
9591139
# config.builder_optimization_level = 0
9601140
in_name = "dets"
@@ -989,6 +1169,29 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic
9891169
)
9901170
if dist.rank == 0:
9911171
print(f"[LER] TensorRT engine built from {onnx_path}")
1172+
inspector = engine.create_engine_inspector()
1173+
if inspector is not None:
1174+
layer_info = inspector.get_engine_information(
1175+
trt.LayerInformationFormat.JSON
1176+
)
1177+
import json as _json
1178+
try:
1179+
info = _json.loads(layer_info)
1180+
layers = info.get("Layers", [])
1181+
precision_counts: dict = {}
1182+
for layer in layers:
1183+
prec = layer.get(
1184+
"LayerPrecision", layer.get("Precision", "unknown")
1185+
)
1186+
precision_counts[prec] = precision_counts.get(prec, 0) + 1
1187+
print(f"[LER] TensorRT engine layer precisions: {precision_counts}")
1188+
except Exception:
1189+
pass
1190+
except ImportError as e:
1191+
raise RuntimeError(
1192+
"[LER] ONNX_WORKFLOW=2 (EXPORT_AND_USE_TRT) requires tensorrt to be installed. "
1193+
"Install with: pip install tensorrt"
1194+
) from e
9921195
except Exception as e:
9931196
if dist.rank == 0:
9941197
print(f"[LER] TensorRT build/load failed: {e}; falling back to PyTorch.")
@@ -1030,11 +1233,13 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic
10301233
int(dets.data_ptr()),
10311234
int(L_and_residual_dets.data_ptr()),
10321235
]
1033-
print(f"[LER] Executing TensorRT context with bindings...")
10341236
t_execute_start = time.perf_counter()
10351237
context.execute_v2(bindings=bindings)
10361238
t_execute_end = time.perf_counter()
1037-
print(f"[LER] TensorRT execution completed in {t_execute_end - t_execute_start:.3f}s")
1239+
if batch_idx == 0 and dist.rank == 0:
1240+
print(
1241+
f"[LER] TensorRT first batch executed in {t_execute_end - t_execute_start:.3f}s"
1242+
)
10381243
else:
10391244
L_and_residual_dets = pipeline_module(dets_only)
10401245
pre_L = L_and_residual_dets[:, 0].to(torch.int32)

code/requirements_public_inference.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,7 @@ stim
1919
pymatching
2020
matplotlib
2121
safetensors>=0.4.0
22+
# Optional GPU-only prerequisite (not pip-installed here due to size and CUDA dependency):
23+
# tensorrt -- required for ONNX_WORKFLOW=2 (EXPORT_AND_USE_TRT) and ONNX_WORKFLOW=3
24+
# (USE_ENGINE_ONLY). Install via: pip install tensorrt
25+
# Absent at runtime raises RuntimeError (workflows 2/3 require TensorRT).

code/requirements_public_train.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,10 @@
1414
-r requirements_public_inference.txt
1515
tensorboard
1616
torchinfo
17+
# ONNX quantization (INT8/FP8 via QUANT_FORMAT).
18+
# nvidia-modelopt[onnx] officially caps at Python <3.13 but works on 3.13 in practice.
19+
# check_python_compat.sh installs it with --ignore-requires-python on Python 3.13+.
20+
# For manual installs on Python 3.13+: pip install nvidia-modelopt[onnx] --ignore-requires-python
21+
# onnxruntime is the INT8-only fallback when modelopt is not importable.
22+
nvidia-modelopt[onnx]; python_version < "3.13"
23+
onnxruntime; python_version >= "3.13"

code/scripts/check_python_compat.sh

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,17 @@ if [[ -n "${TORCH_WHL_INDEX}" || -n "${TORCH_CUDA}" ]]; then
7272
fi
7373

7474
pip install -r "${REQ_FILE}"
75+
76+
# nvidia-modelopt[onnx] caps its Python requirement at <3.13, but works on 3.13
77+
# in practice. Re-install with --ignore-requires-python so train CI on 3.13+
78+
# uses modelopt (INT8+FP8) rather than the onnxruntime-only fallback.
79+
if [[ "${MODE}" == "train" ]]; then
80+
if python -c "import sys; exit(0 if sys.version_info >= (3, 13) else 1)" 2>/dev/null; then
81+
echo "Python 3.13+: installing nvidia-modelopt[onnx] --ignore-requires-python"
82+
pip install "nvidia-modelopt[onnx]" --ignore-requires-python
83+
fi
84+
fi
85+
7586
if [[ -n "${EXTRA_PKGS}" ]]; then
7687
pip install ${EXTRA_PKGS}
7788
fi

0 commit comments

Comments
 (0)