Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
4c96e6d
feat(onnx): add QUANT_FORMAT int8/fp8 PTQ via modelopt.onnx
IgorBaratta Mar 11, 2026
dd82393
fix(onnx): re-derive engine_path from final onnx_path after quant fal…
IgorBaratta Mar 11, 2026
65c9222
review: fix run.py, temp file cleanup, YAPF, README ONNX section
ivanbasov Mar 11, 2026
d7b8217
quantize only CNN layers
IgorBaratta Mar 13, 2026
1ce0713
Merge branch 'igor/onnx_quantization' of github.com:NVIDIA/quantum-pr…
IgorBaratta Mar 13, 2026
1537eb2
fix(ci): YAPF, move nvidia-modelopt to train reqs, add prerequisite t…
ivanbasov Mar 16, 2026
5a3f266
feat(onnx): add onnxruntime INT8 fallback for Python 3.13+
ivanbasov Mar 16, 2026
5323547
fix(prereqs): document tensorrt as optional GPU dep, add fallback tests
ivanbasov Mar 16, 2026
e24274c
fix(onnx): use import-based dispatch for modelopt/ort; install modelo…
ivanbasov Mar 16, 2026
af466ea
fix(test): pin ONNX IR version 8 in ort quantize test
ivanbasov Mar 16, 2026
d111c65
test(onnx): add end-to-end mq.quantize() tests for modelopt
ivanbasov Mar 16, 2026
8155fdd
fix(test): use float32 calibration data in TestModeloptQuantize
ivanbasov Mar 16, 2026
ae0f3b1
fix(trt): raise RuntimeError when tensorrt missing for ONNX_WORKFLOW=2/3
ivanbasov Mar 16, 2026
795e58e
fix(test): fix ORT calibration for ort quantize test
ivanbasov Mar 16, 2026
3b779ef
fix(test): skip ort quantize output test when modelopt is installed
ivanbasov Mar 16, 2026
45891e2
review: address PR #14 review comments
ivanbasov Mar 19, 2026
9c726a7
refactor: extract _parse_quant_format() helper from LER
ivanbasov Mar 19, 2026
e53e53c
fix: guard against num_obs < 1 in _collect_calibration_dets
ivanbasov Mar 19, 2026
ba4efa3
docs: fix tensorrt comment — missing TRT now raises RuntimeError
ivanbasov Mar 19, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,48 @@ WORKFLOW=inference bash code/scripts/local_run.sh
```

`MODEL_ID` is the public model identifier (1–5); see `model/registry.py` for the mapping.
The pre-trained public models use `--model-id 1` (R=9) and `--model-id 4` (R=13).

### ONNX export and quantization (optional, post-training)

After training (or starting from the shipped `.safetensors` files), you can export the model to
ONNX and optionally apply INT8 or FP8 post-training quantization for deployment.

Set the `ONNX_WORKFLOW` and (optionally) `QUANT_FORMAT` environment variables before running
inference with `local_run.sh`:

| `ONNX_WORKFLOW` | Behavior |
|---|---|
| `0` (default) | PyTorch inference only, no ONNX export |
| `1` | Export ONNX model and run inference with PyTorch |
| `2` | Export ONNX model and run inference via TensorRT |
| `3` | Load a pre-existing TensorRT engine file and run inference |

```bash
# Export ONNX only (no TensorRT)
ONNX_WORKFLOW=1 WORKFLOW=inference bash code/scripts/local_run.sh

# Export ONNX + apply INT8 quantization + run TensorRT inference
ONNX_WORKFLOW=2 QUANT_FORMAT=int8 WORKFLOW=inference bash code/scripts/local_run.sh

# Export ONNX + apply FP8 quantization + run TensorRT inference
ONNX_WORKFLOW=2 QUANT_FORMAT=fp8 WORKFLOW=inference bash code/scripts/local_run.sh

# Use a pre-built TensorRT engine (skip export)
ONNX_WORKFLOW=3 WORKFLOW=inference bash code/scripts/local_run.sh
```

**Quantization variables:**

| Variable | Default | Description |
|---|---|---|
| `QUANT_FORMAT` | unset | `int8` or `fp8`. Unset means no quantization (FP32 ONNX). |
| `QUANT_CALIB_SAMPLES` | `256` | Calibration samples for INT8/FP8 post-training quantization. |

Notes:
- TensorRT workflows (`ONNX_WORKFLOW=2` or `3`) require `tensorrt` and `modelopt`.
- FP8 quantization failure is fatal. INT8 failure falls back to the FP32 ONNX model silently.
- ONNX and engine files are written to the current working directory.

### GPU selection

Expand Down
223 changes: 214 additions & 9 deletions code/evaluation/logical_error_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import sys
import os
from enum import IntEnum
from pathlib import Path
from typing import Optional


Expand Down Expand Up @@ -50,6 +51,110 @@ def _detect_shm_bytes() -> Optional[int]:
return None


def _parse_quant_format(rank: int = 0) -> str:
"""Read and validate the QUANT_FORMAT environment variable.

Returns the validated format string ('int8' or 'fp8'), or '' if unset or invalid.
Prints a warning on rank 0 when the value is set but not recognised.
"""
quant_format = os.environ.get("QUANT_FORMAT", "").strip().lower()
if quant_format and quant_format not in ("int8", "fp8"):
if rank == 0:
print(f"[LER] Invalid QUANT_FORMAT='{quant_format}', ignoring. Supported: int8, fp8")
quant_format = ""
return quant_format


def _collect_calibration_dets(
test_dataloader,
num_obs: int,
target_samples: int,
expected_width: int,
) -> "np.ndarray":
"""Collect representative detector inputs from a dataloader for ONNX calibration.

Args:
test_dataloader: DataLoader yielding batches with a "dets_and_obs" key.
num_obs: Number of observable columns at the end of dets_and_obs to strip.
target_samples: Desired number of calibration rows.
expected_width: Expected number of detector columns after stripping observables.

Returns:
np.ndarray of shape (target_samples, expected_width), dtype uint8.
"""
if num_obs < 1:
raise ValueError(
f"num_obs must be >= 1, got {num_obs}. "
"dets_and_obs[:, :-0] would silently return an empty tensor."
)
target_samples = max(int(target_samples), 1)
chunks = []
collected = 0
for calib_batch in test_dataloader:
dets_and_obs_batch = calib_batch["dets_and_obs"]
dets_only_batch = dets_and_obs_batch[:, :-num_obs].to(torch.uint8).contiguous()
Comment thread
ivanbasov marked this conversation as resolved.
if int(dets_only_batch.shape[1]) != int(expected_width):
raise RuntimeError(
f"Calibration det width {dets_only_batch.shape[1]} != expected {expected_width}"
)
if dets_only_batch.numel() == 0:
continue
take = min(target_samples - collected, int(dets_only_batch.shape[0]))
if take > 0:
chunks.append(dets_only_batch[:take].cpu().numpy())
collected += take
if collected >= target_samples:
break
if not chunks:
raise RuntimeError("No calibration samples could be collected from test_dataloader.")
calib = np.concatenate(chunks, axis=0)
if calib.shape[0] < target_samples:
reps = int(np.ceil(target_samples / float(calib.shape[0])))
calib = np.tile(calib, (reps, 1))[:target_samples]
return np.ascontiguousarray(calib, dtype=np.uint8)


def _ort_quantize_int8(fp32_onnx_path: str, output_path: str, calib_dets: "np.ndarray") -> None:
"""INT8 static quantization via onnxruntime.quantization (Python 3.13+ fallback).

Used when nvidia-modelopt is unavailable (it does not support Python 3.13+).
Quantises all Conv and Gemm nodes with QInt8 weights and activations using
QDQ format, which is compatible with TensorRT INT8 parsing.

Args:
fp32_onnx_path: Path to the source FP32 ONNX model.
output_path: Destination path for the quantized ONNX model.
calib_dets: Calibration data array of shape (N, det_cols), dtype uint8.
"""
from onnxruntime.quantization import (
CalibrationDataReader,
QuantFormat,
QuantType,
quantize_static,
)

class _DetCalibReader(CalibrationDataReader):

def __init__(self, data):
self._rows = [{"dets": data[i:i + 1].astype("float32")} for i in range(len(data))]
self._iter = iter(self._rows)

def get_next(self):
return next(self._iter, None)

def rewind(self):
self._iter = iter(self._rows)

quantize_static(
fp32_onnx_path,
output_path,
_DetCalibReader(calib_dets),
quant_format=QuantFormat.QDQ,
activation_type=QuantType.QInt8,
weight_type=QuantType.QInt8,
)


def _time_single_shot_latency_stim(
matcher,
baseline_syndromes: np.ndarray,
Expand Down Expand Up @@ -878,8 +983,15 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic
if dist.rank == 0:
print(f"[LER] Invalid ONNX_WORKFLOW='{_workflow_raw}', using 0 (torch only).")
trt_context = None # (context, engine, device_id) when using TensorRT
onnx_path = os.path.join(os.getcwd(), f"predecoder_memory_d{D}_T{T_original}_{basis}.onnx")
engine_path = os.path.join(os.getcwd(), f"predecoder_memory_d{D}_T{T_original}_{basis}.engine")
# --- QUANT_FORMAT: optional quantization (int8, fp8) applied to ONNX after FP32 export ---
quant_format = _parse_quant_format(rank=dist.rank)
quant_suffix = f"_{quant_format}" if quant_format else ""
Comment thread
ivanbasov marked this conversation as resolved.
onnx_path = os.path.join(
os.getcwd(), f"predecoder_memory_d{D}_T{T_original}_{basis}{quant_suffix}.onnx"
)
engine_path = os.path.join(
os.getcwd(), f"predecoder_memory_d{D}_T{T_original}_{basis}{quant_suffix}.engine"
)
half = (D * D - 1) // 2
example_shape = (batch_size_original, 2 * T_original * half)

Expand All @@ -901,6 +1013,11 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic
print(
f"[LER] TensorRT engine loaded from {engine_path} in {t_load_end - t_load_start:.3f}s"
)
except ImportError as e:
raise RuntimeError(
"[LER] ONNX_WORKFLOW=3 (USE_ENGINE_ONLY) requires tensorrt to be installed. "
"Install with: pip install tensorrt"
) from e
except Exception as e:
if dist.rank == 0:
print(f"[LER] TensorRT engine load failed: {e}; falling back to PyTorch.")
Expand All @@ -915,10 +1032,16 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic
if dist.rank == 0:
try:
example_dets = torch.randint(0, 2, example_shape, dtype=torch.uint8, device=device)

# Step 1: Always export FP32 ONNX first
fp32_onnx_path = (
onnx_path
if not quant_format else onnx_path.replace(f"_{quant_format}.onnx", ".onnx")
)
torch.onnx.export(
pipeline_module,
example_dets,
onnx_path,
fp32_onnx_path,
opset_version=18,
external_data=False,
input_names=["dets"],
Expand All @@ -934,27 +1057,84 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic
do_constant_folding=True,
dynamo=False,
)
print(f"[LER] Exported ONNX: {onnx_path}")
print(f"[LER] Exported FP32 ONNX: {fp32_onnx_path}")

# Step 2: If QUANT_FORMAT is set, apply ONNX-level quantization.
# Backend: nvidia-modelopt on Python <3.13; onnxruntime on Python 3.13+
# (nvidia-modelopt does not support Python 3.13+).
if quant_format:
try:
num_obs_for_calib = circuit.num_observables
calib_num_samples = int(os.environ.get("QUANT_CALIB_SAMPLES", "256"))
print(
f"[LER] Collecting {calib_num_samples} calibration samples "
"from inference dataloader..."
)
calib_dets = _collect_calibration_dets(
test_dataloader, num_obs_for_calib, calib_num_samples, example_shape[1]
)

print(
f"[LER] Applying {quant_format.upper()} quantization to ONNX model..."
)
# Prefer modelopt (INT8+FP8); fall back to onnxruntime (INT8 only)
# when modelopt is not installed. On Python 3.13+ modelopt can
# be installed with: pip install nvidia-modelopt[onnx]
# --ignore-requires-python
try:
import modelopt.onnx.quantization as mq
quant_kwargs = {}
if quant_format == "fp8":
quant_kwargs["op_types_to_quantize"] = ["Conv"]
quant_kwargs["high_precision_dtype"] = "fp16"
mq.quantize(
onnx_path=fp32_onnx_path,
quantize_mode=quant_format,
calibration_data={"dets": calib_dets.astype("float32")},
output_path=onnx_path,
**quant_kwargs,
)
except ImportError:
if quant_format == "fp8":
raise RuntimeError(
"[LER] FP8 quantization requires nvidia-modelopt. "
"Install with: pip install 'nvidia-modelopt[onnx]'"
" --ignore-requires-python"
)
_ort_quantize_int8(fp32_onnx_path, onnx_path, calib_dets)
print(f"[LER] Exported quantized ONNX: {onnx_path}")
except Exception as e:
if quant_format == "fp8":
raise RuntimeError(
f"[LER] FP8 ONNX quantization failed (fail-fast): {e}"
) from e
print(f"[LER] ONNX quantization failed: {e}; using FP32 ONNX.")
onnx_path = fp32_onnx_path
except Exception as e:
if dist.rank == 0:
print(f"[LER] ONNX export failed: {e}; falling back to PyTorch.")
onnx_workflow = OnnxWorkflow.TORCH_ONLY
if dist.world_size > 1:
torch.distributed.barrier()
# Re-derive engine_path from the final onnx_path (may have changed on quant fallback)
engine_path = str(Path(onnx_path).with_suffix(".engine"))
if onnx_workflow == OnnxWorkflow.EXPORT_AND_USE_TRT and device.type == "cuda":
try:
import tensorrt as trt
logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)
builder = trt.Builder(logger)
network = builder.create_network(
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
net_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
if quant_format in ("fp8", "int8"):
net_flags |= 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
network = builder.create_network(net_flags)
parser = trt.OnnxParser(network, logger)
with open(onnx_path, "rb") as f:
if not parser.parse(f.read()):
raise RuntimeError("TensorRT ONNX parse failed")
config = builder.create_builder_config()
if not quant_format:
config.set_flag(trt.BuilderFlag.FP16)
# Uncomment this out to speedup engine build time.
# config.builder_optimization_level = 0
in_name = "dets"
Expand Down Expand Up @@ -989,6 +1169,29 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic
)
if dist.rank == 0:
print(f"[LER] TensorRT engine built from {onnx_path}")
inspector = engine.create_engine_inspector()
if inspector is not None:
layer_info = inspector.get_engine_information(
trt.LayerInformationFormat.JSON
)
import json as _json
try:
info = _json.loads(layer_info)
layers = info.get("Layers", [])
precision_counts: dict = {}
for layer in layers:
prec = layer.get(
"LayerPrecision", layer.get("Precision", "unknown")
)
precision_counts[prec] = precision_counts.get(prec, 0) + 1
print(f"[LER] TensorRT engine layer precisions: {precision_counts}")
except Exception:
pass
except ImportError as e:
raise RuntimeError(
"[LER] ONNX_WORKFLOW=2 (EXPORT_AND_USE_TRT) requires tensorrt to be installed. "
"Install with: pip install tensorrt"
) from e
except Exception as e:
if dist.rank == 0:
print(f"[LER] TensorRT build/load failed: {e}; falling back to PyTorch.")
Expand Down Expand Up @@ -1030,11 +1233,13 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic
int(dets.data_ptr()),
int(L_and_residual_dets.data_ptr()),
]
print(f"[LER] Executing TensorRT context with bindings...")
t_execute_start = time.perf_counter()
context.execute_v2(bindings=bindings)
t_execute_end = time.perf_counter()
print(f"[LER] TensorRT execution completed in {t_execute_end - t_execute_start:.3f}s")
if batch_idx == 0 and dist.rank == 0:
print(
f"[LER] TensorRT first batch executed in {t_execute_end - t_execute_start:.3f}s"
)
else:
L_and_residual_dets = pipeline_module(dets_only)
pre_L = L_and_residual_dets[:, 0].to(torch.int32)
Expand Down
4 changes: 4 additions & 0 deletions code/requirements_public_inference.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,7 @@ stim
pymatching
matplotlib
safetensors>=0.4.0
# Optional GPU-only prerequisite (not pip-installed here due to size and CUDA dependency):
# tensorrt -- required for ONNX_WORKFLOW=2 (EXPORT_AND_USE_TRT) and ONNX_WORKFLOW=3
# (USE_ENGINE_ONLY). Install via: pip install tensorrt
# Absent at runtime raises RuntimeError (workflows 2/3 require TensorRT).
7 changes: 7 additions & 0 deletions code/requirements_public_train.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,10 @@
-r requirements_public_inference.txt
tensorboard
torchinfo
# ONNX quantization (INT8/FP8 via QUANT_FORMAT).
# nvidia-modelopt[onnx] officially caps at Python <3.13 but works on 3.13 in practice.
# check_python_compat.sh installs it with --ignore-requires-python on Python 3.13+.
# For manual installs on Python 3.13+: pip install nvidia-modelopt[onnx] --ignore-requires-python
# onnxruntime is the INT8-only fallback when modelopt is not importable.
nvidia-modelopt[onnx]; python_version < "3.13"
onnxruntime; python_version >= "3.13"
11 changes: 11 additions & 0 deletions code/scripts/check_python_compat.sh
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,17 @@ if [[ -n "${TORCH_WHL_INDEX}" || -n "${TORCH_CUDA}" ]]; then
fi

pip install -r "${REQ_FILE}"

# nvidia-modelopt[onnx] caps its Python requirement at <3.13, but works on 3.13
# in practice. Re-install with --ignore-requires-python so train CI on 3.13+
# uses modelopt (INT8+FP8) rather than the onnxruntime-only fallback.
if [[ "${MODE}" == "train" ]]; then
if python -c "import sys; exit(0 if sys.version_info >= (3, 13) else 1)" 2>/dev/null; then
echo "Python 3.13+: installing nvidia-modelopt[onnx] --ignore-requires-python"
pip install "nvidia-modelopt[onnx]" --ignore-requires-python
fi
fi

if [[ -n "${EXTRA_PKGS}" ]]; then
pip install ${EXTRA_PKGS}
fi
Expand Down
Loading
Loading