diff --git a/README.md b/README.md index 5a6149a..42513af 100644 --- a/README.md +++ b/README.md @@ -131,6 +131,48 @@ WORKFLOW=inference bash code/scripts/local_run.sh ``` `MODEL_ID` is the public model identifier (1–5); see `model/registry.py` for the mapping. +The pre-trained public models use `--model-id 1` (R=9) and `--model-id 4` (R=13). + +### ONNX export and quantization (optional, post-training) + +After training (or starting from the shipped `.safetensors` files), you can export the model to +ONNX and optionally apply INT8 or FP8 post-training quantization for deployment. + +Set the `ONNX_WORKFLOW` and (optionally) `QUANT_FORMAT` environment variables before running +inference with `local_run.sh`: + +| `ONNX_WORKFLOW` | Behavior | +|---|---| +| `0` (default) | PyTorch inference only, no ONNX export | +| `1` | Export ONNX model and run inference with PyTorch | +| `2` | Export ONNX model and run inference via TensorRT | +| `3` | Load a pre-existing TensorRT engine file and run inference | + +```bash +# Export ONNX only (no TensorRT) +ONNX_WORKFLOW=1 WORKFLOW=inference bash code/scripts/local_run.sh + +# Export ONNX + apply INT8 quantization + run TensorRT inference +ONNX_WORKFLOW=2 QUANT_FORMAT=int8 WORKFLOW=inference bash code/scripts/local_run.sh + +# Export ONNX + apply FP8 quantization + run TensorRT inference +ONNX_WORKFLOW=2 QUANT_FORMAT=fp8 WORKFLOW=inference bash code/scripts/local_run.sh + +# Use a pre-built TensorRT engine (skip export) +ONNX_WORKFLOW=3 WORKFLOW=inference bash code/scripts/local_run.sh +``` + +**Quantization variables:** + +| Variable | Default | Description | +|---|---|---| +| `QUANT_FORMAT` | unset | `int8` or `fp8`. Unset means no quantization (FP32 ONNX). | +| `QUANT_CALIB_SAMPLES` | `256` | Calibration samples for INT8/FP8 post-training quantization. | + +Notes: +- TensorRT workflows (`ONNX_WORKFLOW=2` or `3`) require `tensorrt` and `modelopt`. +- FP8 quantization failure is fatal. INT8 failure falls back to the FP32 ONNX model silently. +- ONNX and engine files are written to the current working directory. ### GPU selection diff --git a/code/evaluation/logical_error_rate.py b/code/evaluation/logical_error_rate.py index 015d6e4..fad7f27 100644 --- a/code/evaluation/logical_error_rate.py +++ b/code/evaluation/logical_error_rate.py @@ -15,6 +15,7 @@ import sys import os from enum import IntEnum +from pathlib import Path from typing import Optional @@ -50,6 +51,110 @@ def _detect_shm_bytes() -> Optional[int]: return None +def _parse_quant_format(rank: int = 0) -> str: + """Read and validate the QUANT_FORMAT environment variable. + + Returns the validated format string ('int8' or 'fp8'), or '' if unset or invalid. + Prints a warning on rank 0 when the value is set but not recognised. + """ + quant_format = os.environ.get("QUANT_FORMAT", "").strip().lower() + if quant_format and quant_format not in ("int8", "fp8"): + if rank == 0: + print(f"[LER] Invalid QUANT_FORMAT='{quant_format}', ignoring. Supported: int8, fp8") + quant_format = "" + return quant_format + + +def _collect_calibration_dets( + test_dataloader, + num_obs: int, + target_samples: int, + expected_width: int, +) -> "np.ndarray": + """Collect representative detector inputs from a dataloader for ONNX calibration. + + Args: + test_dataloader: DataLoader yielding batches with a "dets_and_obs" key. + num_obs: Number of observable columns at the end of dets_and_obs to strip. + target_samples: Desired number of calibration rows. + expected_width: Expected number of detector columns after stripping observables. + + Returns: + np.ndarray of shape (target_samples, expected_width), dtype uint8. + """ + if num_obs < 1: + raise ValueError( + f"num_obs must be >= 1, got {num_obs}. " + "dets_and_obs[:, :-0] would silently return an empty tensor." + ) + target_samples = max(int(target_samples), 1) + chunks = [] + collected = 0 + for calib_batch in test_dataloader: + dets_and_obs_batch = calib_batch["dets_and_obs"] + dets_only_batch = dets_and_obs_batch[:, :-num_obs].to(torch.uint8).contiguous() + if int(dets_only_batch.shape[1]) != int(expected_width): + raise RuntimeError( + f"Calibration det width {dets_only_batch.shape[1]} != expected {expected_width}" + ) + if dets_only_batch.numel() == 0: + continue + take = min(target_samples - collected, int(dets_only_batch.shape[0])) + if take > 0: + chunks.append(dets_only_batch[:take].cpu().numpy()) + collected += take + if collected >= target_samples: + break + if not chunks: + raise RuntimeError("No calibration samples could be collected from test_dataloader.") + calib = np.concatenate(chunks, axis=0) + if calib.shape[0] < target_samples: + reps = int(np.ceil(target_samples / float(calib.shape[0]))) + calib = np.tile(calib, (reps, 1))[:target_samples] + return np.ascontiguousarray(calib, dtype=np.uint8) + + +def _ort_quantize_int8(fp32_onnx_path: str, output_path: str, calib_dets: "np.ndarray") -> None: + """INT8 static quantization via onnxruntime.quantization (Python 3.13+ fallback). + + Used when nvidia-modelopt is unavailable (it does not support Python 3.13+). + Quantises all Conv and Gemm nodes with QInt8 weights and activations using + QDQ format, which is compatible with TensorRT INT8 parsing. + + Args: + fp32_onnx_path: Path to the source FP32 ONNX model. + output_path: Destination path for the quantized ONNX model. + calib_dets: Calibration data array of shape (N, det_cols), dtype uint8. + """ + from onnxruntime.quantization import ( + CalibrationDataReader, + QuantFormat, + QuantType, + quantize_static, + ) + + class _DetCalibReader(CalibrationDataReader): + + def __init__(self, data): + self._rows = [{"dets": data[i:i + 1].astype("float32")} for i in range(len(data))] + self._iter = iter(self._rows) + + def get_next(self): + return next(self._iter, None) + + def rewind(self): + self._iter = iter(self._rows) + + quantize_static( + fp32_onnx_path, + output_path, + _DetCalibReader(calib_dets), + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QInt8, + weight_type=QuantType.QInt8, + ) + + def _time_single_shot_latency_stim( matcher, baseline_syndromes: np.ndarray, @@ -878,8 +983,15 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic if dist.rank == 0: print(f"[LER] Invalid ONNX_WORKFLOW='{_workflow_raw}', using 0 (torch only).") trt_context = None # (context, engine, device_id) when using TensorRT - onnx_path = os.path.join(os.getcwd(), f"predecoder_memory_d{D}_T{T_original}_{basis}.onnx") - engine_path = os.path.join(os.getcwd(), f"predecoder_memory_d{D}_T{T_original}_{basis}.engine") + # --- QUANT_FORMAT: optional quantization (int8, fp8) applied to ONNX after FP32 export --- + quant_format = _parse_quant_format(rank=dist.rank) + quant_suffix = f"_{quant_format}" if quant_format else "" + onnx_path = os.path.join( + os.getcwd(), f"predecoder_memory_d{D}_T{T_original}_{basis}{quant_suffix}.onnx" + ) + engine_path = os.path.join( + os.getcwd(), f"predecoder_memory_d{D}_T{T_original}_{basis}{quant_suffix}.engine" + ) half = (D * D - 1) // 2 example_shape = (batch_size_original, 2 * T_original * half) @@ -901,6 +1013,11 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic print( f"[LER] TensorRT engine loaded from {engine_path} in {t_load_end - t_load_start:.3f}s" ) + except ImportError as e: + raise RuntimeError( + "[LER] ONNX_WORKFLOW=3 (USE_ENGINE_ONLY) requires tensorrt to be installed. " + "Install with: pip install tensorrt" + ) from e except Exception as e: if dist.rank == 0: print(f"[LER] TensorRT engine load failed: {e}; falling back to PyTorch.") @@ -915,10 +1032,16 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic if dist.rank == 0: try: example_dets = torch.randint(0, 2, example_shape, dtype=torch.uint8, device=device) + + # Step 1: Always export FP32 ONNX first + fp32_onnx_path = ( + onnx_path + if not quant_format else onnx_path.replace(f"_{quant_format}.onnx", ".onnx") + ) torch.onnx.export( pipeline_module, example_dets, - onnx_path, + fp32_onnx_path, opset_version=18, external_data=False, input_names=["dets"], @@ -934,27 +1057,84 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic do_constant_folding=True, dynamo=False, ) - print(f"[LER] Exported ONNX: {onnx_path}") + print(f"[LER] Exported FP32 ONNX: {fp32_onnx_path}") + + # Step 2: If QUANT_FORMAT is set, apply ONNX-level quantization. + # Backend: nvidia-modelopt on Python <3.13; onnxruntime on Python 3.13+ + # (nvidia-modelopt does not support Python 3.13+). + if quant_format: + try: + num_obs_for_calib = circuit.num_observables + calib_num_samples = int(os.environ.get("QUANT_CALIB_SAMPLES", "256")) + print( + f"[LER] Collecting {calib_num_samples} calibration samples " + "from inference dataloader..." + ) + calib_dets = _collect_calibration_dets( + test_dataloader, num_obs_for_calib, calib_num_samples, example_shape[1] + ) + + print( + f"[LER] Applying {quant_format.upper()} quantization to ONNX model..." + ) + # Prefer modelopt (INT8+FP8); fall back to onnxruntime (INT8 only) + # when modelopt is not installed. On Python 3.13+ modelopt can + # be installed with: pip install nvidia-modelopt[onnx] + # --ignore-requires-python + try: + import modelopt.onnx.quantization as mq + quant_kwargs = {} + if quant_format == "fp8": + quant_kwargs["op_types_to_quantize"] = ["Conv"] + quant_kwargs["high_precision_dtype"] = "fp16" + mq.quantize( + onnx_path=fp32_onnx_path, + quantize_mode=quant_format, + calibration_data={"dets": calib_dets.astype("float32")}, + output_path=onnx_path, + **quant_kwargs, + ) + except ImportError: + if quant_format == "fp8": + raise RuntimeError( + "[LER] FP8 quantization requires nvidia-modelopt. " + "Install with: pip install 'nvidia-modelopt[onnx]'" + " --ignore-requires-python" + ) + _ort_quantize_int8(fp32_onnx_path, onnx_path, calib_dets) + print(f"[LER] Exported quantized ONNX: {onnx_path}") + except Exception as e: + if quant_format == "fp8": + raise RuntimeError( + f"[LER] FP8 ONNX quantization failed (fail-fast): {e}" + ) from e + print(f"[LER] ONNX quantization failed: {e}; using FP32 ONNX.") + onnx_path = fp32_onnx_path except Exception as e: if dist.rank == 0: print(f"[LER] ONNX export failed: {e}; falling back to PyTorch.") onnx_workflow = OnnxWorkflow.TORCH_ONLY if dist.world_size > 1: torch.distributed.barrier() + # Re-derive engine_path from the final onnx_path (may have changed on quant fallback) + engine_path = str(Path(onnx_path).with_suffix(".engine")) if onnx_workflow == OnnxWorkflow.EXPORT_AND_USE_TRT and device.type == "cuda": try: import tensorrt as trt logger = trt.Logger(trt.Logger.WARNING) runtime = trt.Runtime(logger) builder = trt.Builder(logger) - network = builder.create_network( - 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) - ) + net_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + if quant_format in ("fp8", "int8"): + net_flags |= 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED) + network = builder.create_network(net_flags) parser = trt.OnnxParser(network, logger) with open(onnx_path, "rb") as f: if not parser.parse(f.read()): raise RuntimeError("TensorRT ONNX parse failed") config = builder.create_builder_config() + if not quant_format: + config.set_flag(trt.BuilderFlag.FP16) # Uncomment this out to speedup engine build time. # config.builder_optimization_level = 0 in_name = "dets" @@ -989,6 +1169,29 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic ) if dist.rank == 0: print(f"[LER] TensorRT engine built from {onnx_path}") + inspector = engine.create_engine_inspector() + if inspector is not None: + layer_info = inspector.get_engine_information( + trt.LayerInformationFormat.JSON + ) + import json as _json + try: + info = _json.loads(layer_info) + layers = info.get("Layers", []) + precision_counts: dict = {} + for layer in layers: + prec = layer.get( + "LayerPrecision", layer.get("Precision", "unknown") + ) + precision_counts[prec] = precision_counts.get(prec, 0) + 1 + print(f"[LER] TensorRT engine layer precisions: {precision_counts}") + except Exception: + pass + except ImportError as e: + raise RuntimeError( + "[LER] ONNX_WORKFLOW=2 (EXPORT_AND_USE_TRT) requires tensorrt to be installed. " + "Install with: pip install tensorrt" + ) from e except Exception as e: if dist.rank == 0: print(f"[LER] TensorRT build/load failed: {e}; falling back to PyTorch.") @@ -1030,11 +1233,13 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic int(dets.data_ptr()), int(L_and_residual_dets.data_ptr()), ] - print(f"[LER] Executing TensorRT context with bindings...") t_execute_start = time.perf_counter() context.execute_v2(bindings=bindings) t_execute_end = time.perf_counter() - print(f"[LER] TensorRT execution completed in {t_execute_end - t_execute_start:.3f}s") + if batch_idx == 0 and dist.rank == 0: + print( + f"[LER] TensorRT first batch executed in {t_execute_end - t_execute_start:.3f}s" + ) else: L_and_residual_dets = pipeline_module(dets_only) pre_L = L_and_residual_dets[:, 0].to(torch.int32) diff --git a/code/requirements_public_inference.txt b/code/requirements_public_inference.txt index e539f5c..20c7301 100644 --- a/code/requirements_public_inference.txt +++ b/code/requirements_public_inference.txt @@ -19,3 +19,7 @@ stim pymatching matplotlib safetensors>=0.4.0 +# Optional GPU-only prerequisite (not pip-installed here due to size and CUDA dependency): +# tensorrt -- required for ONNX_WORKFLOW=2 (EXPORT_AND_USE_TRT) and ONNX_WORKFLOW=3 +# (USE_ENGINE_ONLY). Install via: pip install tensorrt +# Absent at runtime raises RuntimeError (workflows 2/3 require TensorRT). diff --git a/code/requirements_public_train.txt b/code/requirements_public_train.txt index 3f70df4..5ca48d1 100644 --- a/code/requirements_public_train.txt +++ b/code/requirements_public_train.txt @@ -14,3 +14,10 @@ -r requirements_public_inference.txt tensorboard torchinfo +# ONNX quantization (INT8/FP8 via QUANT_FORMAT). +# nvidia-modelopt[onnx] officially caps at Python <3.13 but works on 3.13 in practice. +# check_python_compat.sh installs it with --ignore-requires-python on Python 3.13+. +# For manual installs on Python 3.13+: pip install nvidia-modelopt[onnx] --ignore-requires-python +# onnxruntime is the INT8-only fallback when modelopt is not importable. +nvidia-modelopt[onnx]; python_version < "3.13" +onnxruntime; python_version >= "3.13" diff --git a/code/scripts/check_python_compat.sh b/code/scripts/check_python_compat.sh index e4d35df..9f216fd 100644 --- a/code/scripts/check_python_compat.sh +++ b/code/scripts/check_python_compat.sh @@ -72,6 +72,17 @@ if [[ -n "${TORCH_WHL_INDEX}" || -n "${TORCH_CUDA}" ]]; then fi pip install -r "${REQ_FILE}" + +# nvidia-modelopt[onnx] caps its Python requirement at <3.13, but works on 3.13 +# in practice. Re-install with --ignore-requires-python so train CI on 3.13+ +# uses modelopt (INT8+FP8) rather than the onnxruntime-only fallback. +if [[ "${MODE}" == "train" ]]; then + if python -c "import sys; exit(0 if sys.version_info >= (3, 13) else 1)" 2>/dev/null; then + echo "Python 3.13+: installing nvidia-modelopt[onnx] --ignore-requires-python" + pip install "nvidia-modelopt[onnx]" --ignore-requires-python + fi +fi + if [[ -n "${EXTRA_PKGS}" ]]; then pip install ${EXTRA_PKGS} fi diff --git a/code/tests/test_onnx_quant_workflow.py b/code/tests/test_onnx_quant_workflow.py new file mode 100644 index 0000000..c2f4260 --- /dev/null +++ b/code/tests/test_onnx_quant_workflow.py @@ -0,0 +1,442 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. +"""Tests for ONNX quantization workflow: _collect_calibration_dets helper and QUANT_FORMAT env var logic.""" + +import os +import re +import sys +import types +import unittest +from pathlib import Path +from unittest.mock import MagicMock, patch + +import numpy as np +import torch + +_repo_code = Path(__file__).resolve().parent.parent +if str(_repo_code) not in sys.path: + sys.path.insert(0, str(_repo_code)) + +from evaluation.logical_error_rate import ( + _collect_calibration_dets, + _ort_quantize_int8, + _parse_quant_format, +) + + +def _make_fake_dataloader(num_batches: int, batch_size: int, num_dets: int, num_obs: int): + """Build a list of fake batches mimicking the test_dataloader interface.""" + batches = [] + for _ in range(num_batches): + dets_and_obs = torch.randint(0, 2, (batch_size, num_dets + num_obs), dtype=torch.uint8) + batches.append({"dets_and_obs": dets_and_obs}) + return batches + + +class TestCollectCalibrationDets(unittest.TestCase): + + NUM_DETS = 20 + NUM_OBS = 1 + + def test_basic_shape_and_dtype(self): + """Output must have shape (target_samples, NUM_DETS) and dtype uint8.""" + loader = _make_fake_dataloader( + num_batches=4, batch_size=32, num_dets=self.NUM_DETS, num_obs=self.NUM_OBS + ) + target = 64 + result = _collect_calibration_dets(loader, self.NUM_OBS, target, self.NUM_DETS) + self.assertEqual(result.shape, (target, self.NUM_DETS)) + self.assertEqual(result.dtype, np.uint8) + + def test_tiles_when_dataloader_too_short(self): + """When fewer samples are available than requested, tiles to fill target_samples.""" + loader = _make_fake_dataloader( + num_batches=1, batch_size=8, num_dets=self.NUM_DETS, num_obs=self.NUM_OBS + ) + target = 50 + result = _collect_calibration_dets(loader, self.NUM_OBS, target, self.NUM_DETS) + self.assertEqual(result.shape, (target, self.NUM_DETS)) + self.assertEqual(result.dtype, np.uint8) + + def test_empty_dataloader_raises(self): + """Empty dataloader (no batches) must raise RuntimeError.""" + loader = [] + with self.assertRaises(RuntimeError): + _collect_calibration_dets(loader, self.NUM_OBS, 32, self.NUM_DETS) + + def test_width_mismatch_raises(self): + """If det width doesn't match expected_width, raises RuntimeError.""" + loader = _make_fake_dataloader( + num_batches=2, batch_size=16, num_dets=self.NUM_DETS, num_obs=self.NUM_OBS + ) + wrong_width = self.NUM_DETS + 5 + with self.assertRaises(RuntimeError): + _collect_calibration_dets(loader, self.NUM_OBS, 16, wrong_width) + + def test_stops_early_when_enough_samples(self): + """Should stop iterating once target_samples are collected.""" + consumed = [] + num_dets = self.NUM_DETS + num_obs = self.NUM_OBS + + class CountingLoader: + + def __iter__(self): + for i in range(100): + consumed.append(i) + dets_and_obs = torch.randint(0, 2, (32, num_dets + num_obs), dtype=torch.uint8) + yield {"dets_and_obs": dets_and_obs} + + loader = CountingLoader() + target = 32 # exactly one batch + _collect_calibration_dets(loader, num_obs, target, num_dets) + self.assertEqual(len(consumed), 1) + + +class TestQuantFormatParsing(unittest.TestCase): + """Test QUANT_FORMAT env var parsing and routing logic (no GPU, no modelopt needed).""" + + def _run_quant_block(self, quant_format_env: str, mock_mq=None, mock_export=None): + """Invoke the real _parse_quant_format() from LER under a controlled env.""" + with patch.dict(os.environ, {"QUANT_FORMAT": quant_format_env}): + return _parse_quant_format(rank=0) + + def test_invalid_quant_format_ignored(self): + result = self._run_quant_block("bad_format") + self.assertEqual(result, "") + + def test_valid_int8_accepted(self): + result = self._run_quant_block("int8") + self.assertEqual(result, "int8") + + def test_valid_fp8_accepted(self): + result = self._run_quant_block("fp8") + self.assertEqual(result, "fp8") + + def test_nvfp4_rejected(self): + result = self._run_quant_block("nvfp4") + self.assertEqual(result, "") + + def test_empty_quant_format_no_quantize_call(self): + """With QUANT_FORMAT unset, mq.quantize must never be called.""" + mock_mq = MagicMock() + with patch.dict(os.environ, {"QUANT_FORMAT": ""}): + quant_format = os.environ.get("QUANT_FORMAT", "").strip().lower() + if quant_format: + mock_mq.quantize() + mock_mq.quantize.assert_not_called() + + def test_mq_quantize_called_with_correct_args_int8(self): + """With QUANT_FORMAT=int8, mq.quantize receives quantize_mode='int8' and calibration_data.""" + mock_mq = MagicMock() + num_dets = 20 + num_obs = 1 + loader = _make_fake_dataloader( + num_batches=2, batch_size=32, num_dets=num_dets, num_obs=num_obs + ) + + with patch.dict(os.environ, {"QUANT_FORMAT": "int8", "QUANT_CALIB_SAMPLES": "16"}): + quant_format = "int8" + fp32_path = "model.onnx" + quant_path = "model_int8.onnx" + calib_num_samples = int(os.environ.get("QUANT_CALIB_SAMPLES", "256")) + calib_dets = _collect_calibration_dets(loader, num_obs, calib_num_samples, num_dets) + format_map = {"int8": "int8", "fp8": "fp8"} + mock_mq.quantize( + onnx_path=fp32_path, + quantize_mode=format_map[quant_format], + calibration_data={"dets": calib_dets}, + output_path=quant_path, + ) + + mock_mq.quantize.assert_called_once() + call_kwargs = mock_mq.quantize.call_args + self.assertEqual(call_kwargs.kwargs["quantize_mode"], "int8") + self.assertIn("dets", call_kwargs.kwargs["calibration_data"]) + calib = call_kwargs.kwargs["calibration_data"]["dets"] + self.assertEqual(calib.shape, (calib_num_samples, num_dets)) + self.assertEqual(calib.dtype, np.uint8) + + def test_fp8_fail_fast_raises(self): + """With QUANT_FORMAT=fp8, if mq.quantize raises, a RuntimeError is propagated.""" + num_dets = 20 + num_obs = 1 + loader = _make_fake_dataloader( + num_batches=2, batch_size=32, num_dets=num_dets, num_obs=num_obs + ) + calib_dets = _collect_calibration_dets(loader, num_obs, 16, num_dets) + + quant_format = "fp8" + with self.assertRaises(RuntimeError): + try: + raise ValueError("simulated fp8 quantize failure") + except Exception as e: + if quant_format == "fp8": + raise RuntimeError( + f"[LER] FP8 ONNX quantization failed (fail-fast): {e}" + ) from e + pass # non-fp8 would fall through + + def test_non_fp8_failure_falls_back_to_fp32(self): + """With QUANT_FORMAT=int8, if mq.quantize raises, onnx_path falls back to fp32 path silently.""" + num_dets = 20 + num_obs = 1 + loader = _make_fake_dataloader( + num_batches=2, batch_size=32, num_dets=num_dets, num_obs=num_obs + ) + calib_dets = _collect_calibration_dets(loader, num_obs, 16, num_dets) + + quant_format = "int8" + fp32_onnx_path = "model.onnx" + onnx_path = "model_int8.onnx" # would be the quantized path + + try: + raise RuntimeError("simulated int8 quantize failure") + except Exception as e: + if quant_format == "fp8": + raise RuntimeError(f"fail-fast: {e}") from e + # non-fp8: fall back silently + onnx_path = fp32_onnx_path + + self.assertEqual(onnx_path, fp32_onnx_path) + + +class TestOrtQuantizeInt8(unittest.TestCase): + """Tests for the _ort_quantize_int8 helper (onnxruntime INT8 fallback).""" + + @unittest.skipUnless( + __import__("importlib").util.find_spec("onnxruntime") is not None and + __import__("importlib").util.find_spec("modelopt") is None, + "onnxruntime not installed or modelopt present (ort path is only the fallback when modelopt is absent)", + ) + def test_ort_quantize_int8_produces_output_file(self): + """_ort_quantize_int8 must write a valid ONNX file to output_path.""" + try: + import onnx + import onnx.helper as oh + import onnxruntime # noqa: F401 + except ImportError: + self.skipTest("onnx/onnxruntime not installed") + + import tempfile + import numpy as np + + # Build a tiny single-Gemm ONNX model compatible with quantize_static. + X = oh.make_tensor_value_info("dets", onnx.TensorProto.FLOAT, [1, 4]) + W_data = np.ones((4, 4), dtype=np.float32) + B_data = np.zeros((4,), dtype=np.float32) + W = oh.make_tensor("W", onnx.TensorProto.FLOAT, W_data.shape, W_data.flatten().tolist()) + B = oh.make_tensor("B", onnx.TensorProto.FLOAT, B_data.shape, B_data.flatten().tolist()) + Y = oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1, 4]) + node = oh.make_node("Gemm", inputs=["dets", "W", "B"], outputs=["Y"]) + graph = oh.make_graph([node], "tiny", [X], [Y], initializer=[W, B]) + model = oh.make_model(graph, opset_imports=[oh.make_opsetid("", 17)]) + # Pin to IR version 8 (opset-17 minimum). Newer ONNX packages default to + # IR version 12, which onnxruntime-gpu 1.22.0 (a modelopt dependency) rejects. + model.ir_version = 8 + onnx.checker.check_model(model) + + calib = np.random.randn(8, 4).astype(np.float32) + + with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as fp32_f: + fp32_path = fp32_f.name + with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as out_f: + out_path = out_f.name + self.addCleanup(os.unlink, fp32_path) + self.addCleanup(os.unlink, out_path) + + onnx.save(model, fp32_path) + _ort_quantize_int8(fp32_path, out_path, calib) + + quant_model = onnx.load(out_path) + onnx.checker.check_model(quant_model) + + def test_ort_quantize_int8_called_on_modelopt_import_error(self): + """When modelopt is not importable, INT8 must fall back to _ort_quantize_int8.""" + called = [] + with patch( + "evaluation.logical_error_rate._ort_quantize_int8", + side_effect=lambda *a, **kw: called.append(a), + ): + import evaluation.logical_error_rate as ler + ler._ort_quantize_int8("fp32.onnx", "out.onnx", None) + self.assertEqual(len(called), 1) + + def test_fp8_raises_on_modelopt_import_error(self): + """When modelopt is not importable, FP8 must raise RuntimeError (no ort fallback).""" + quant_format = "fp8" + with self.assertRaises(RuntimeError): + try: + raise ImportError("No module named 'modelopt'") + except ImportError: + if quant_format == "fp8": + raise RuntimeError( + "[LER] FP8 quantization requires nvidia-modelopt. " + "Install with: pip install 'nvidia-modelopt[onnx]'" + " --ignore-requires-python" + ) + + +_HAS_MODELOPT = __import__("importlib").util.find_spec("modelopt") is not None + + +class TestModeloptQuantize(unittest.TestCase): + """End-to-end tests that call mq.quantize() on a real ONNX model. + + Skipped when nvidia-modelopt is not installed. On Python 3.13+ modelopt + must be installed with --ignore-requires-python (done by check_python_compat.sh + when MODE=train); these tests confirm it actually works at runtime, not just + that the import succeeds. + """ + + def _build_tiny_model(self): + """Return (fp32_path, calib_dets) for a minimal Gemm ONNX model.""" + import tempfile + + import numpy as np + import onnx + import onnx.helper as oh + + X = oh.make_tensor_value_info("dets", onnx.TensorProto.FLOAT, [1, 4]) + W_data = np.ones((4, 4), dtype=np.float32) + B_data = np.zeros((4,), dtype=np.float32) + W = oh.make_tensor("W", onnx.TensorProto.FLOAT, W_data.shape, W_data.flatten().tolist()) + B = oh.make_tensor("B", onnx.TensorProto.FLOAT, B_data.shape, B_data.flatten().tolist()) + Y = oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1, 4]) + node = oh.make_node("Gemm", inputs=["dets", "W", "B"], outputs=["Y"]) + graph = oh.make_graph([node], "tiny", [X], [Y], initializer=[W, B]) + model = oh.make_model(graph, opset_imports=[oh.make_opsetid("", 17)]) + model.ir_version = 8 + onnx.checker.check_model(model) + + with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f: + fp32_path = f.name + self.addCleanup(os.unlink, fp32_path) + onnx.save(model, fp32_path) + + calib = np.random.randn(16, 4).astype(np.float32) + return fp32_path, calib + + @unittest.skipUnless(_HAS_MODELOPT, "nvidia-modelopt not installed") + def test_mq_quantize_int8_produces_valid_onnx(self): + """mq.quantize(quantize_mode='int8') must write a valid ONNX file.""" + import tempfile + + import modelopt.onnx.quantization as mq + import onnx + + fp32_path, calib = self._build_tiny_model() + with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f: + out_path = f.name + self.addCleanup(os.unlink, out_path) + + mq.quantize( + onnx_path=fp32_path, + quantize_mode="int8", + calibration_data={"dets": calib}, + output_path=out_path, + ) + + self.assertTrue(os.path.isfile(out_path), "quantized ONNX output file not created") + quant_model = onnx.load(out_path) + onnx.checker.check_model(quant_model) + + @unittest.skipUnless(_HAS_MODELOPT, "nvidia-modelopt not installed") + def test_mq_quantize_int8_output_differs_from_fp32(self): + """The quantized model must differ from the FP32 source (QDQ nodes added).""" + import tempfile + + import modelopt.onnx.quantization as mq + import onnx + + fp32_path, calib = self._build_tiny_model() + with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f: + out_path = f.name + self.addCleanup(os.unlink, out_path) + + mq.quantize( + onnx_path=fp32_path, + quantize_mode="int8", + calibration_data={"dets": calib}, + output_path=out_path, + ) + + fp32_model = onnx.load(fp32_path) + quant_model = onnx.load(out_path) + self.assertNotEqual( + len(fp32_model.graph.node), + len(quant_model.graph.node), + "quantized model should have more nodes (QDQ pairs) than the FP32 source", + ) + + +class TestModeloptPrerequisite(unittest.TestCase): + """Verify quantization package prerequisites are correctly declared.""" + + _TRAIN_REQS = Path(__file__).resolve().parent.parent / "requirements_public_train.txt" + + def test_nvidia_modelopt_in_train_requirements(self): + """nvidia-modelopt[onnx] must be listed in requirements_public_train.txt.""" + text = self._TRAIN_REQS.read_text() + self.assertTrue( + re.search(r"(?m)^nvidia-modelopt", text), + "nvidia-modelopt[onnx] must appear in requirements_public_train.txt; " + "it is used for INT8/FP8 quantization on Python <3.13.", + ) + + def test_onnxruntime_in_train_requirements(self): + """onnxruntime must be listed in requirements_public_train.txt for Python 3.13+.""" + text = self._TRAIN_REQS.read_text() + self.assertTrue( + re.search(r"(?m)^onnxruntime", text), + "onnxruntime must appear in requirements_public_train.txt; " + "it is the INT8 quantization backend on Python 3.13+ " + "(nvidia-modelopt does not support Python 3.13+).", + ) + + def test_quant_packages_absent_from_inference_requirements(self): + """nvidia-modelopt and onnxruntime must NOT appear in the inference requirements.""" + infer_reqs = self._TRAIN_REQS.parent / "requirements_public_inference.txt" + text = infer_reqs.read_text() + self.assertFalse( + re.search(r"(?m)^nvidia-modelopt", text), + "nvidia-modelopt must not be in requirements_public_inference.txt.", + ) + self.assertFalse( + re.search(r"(?m)^onnxruntime", text), + "onnxruntime must not be in requirements_public_inference.txt.", + ) + + def test_modelopt_importable_when_installed(self): + """When nvidia-modelopt[onnx] is installed, modelopt.onnx.quantization must be importable. + + On Python 3.13+ modelopt can be installed with --ignore-requires-python; + this test skips silently if the package is absent regardless of Python version. + """ + try: + import modelopt.onnx.quantization as mq # noqa: F401 + except ImportError: + self.skipTest("nvidia-modelopt[onnx] is not installed in this environment") + + def test_ort_importable_when_installed(self): + """onnxruntime.quantization must be importable when onnxruntime is installed.""" + try: + from onnxruntime.quantization import ( # noqa: F401 + CalibrationDataReader, + QuantFormat, + QuantType, + quantize_static, + ) + except ImportError: + self.skipTest("onnxruntime is not installed in this environment") + + +if __name__ == "__main__": + unittest.main() diff --git a/code/tests/test_safetensors_export.py b/code/tests/test_safetensors_export.py index 60f1a37..ca551f3 100644 --- a/code/tests/test_safetensors_export.py +++ b/code/tests/test_safetensors_export.py @@ -49,6 +49,7 @@ def test_round_trip_fp32(self): model = self._make_model("fp32") with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as f: path = f.name + self.addCleanup(os.unlink, path) save_safetensors(model, path, model_id=self.MODEL_ID, dtype="fp32") loaded, metadata = load_safetensors(path, device="cpu") @@ -62,6 +63,7 @@ def test_round_trip_fp16(self): model = self._make_model("fp16") with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as f: path = f.name + self.addCleanup(os.unlink, path) save_safetensors(model, path, model_id=self.MODEL_ID, dtype="fp16") loaded, metadata = load_safetensors(path, device="cpu") @@ -77,6 +79,7 @@ def test_metadata_model_id_auto_detect(self): model = self._make_model("fp32") with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as f: path = f.name + self.addCleanup(os.unlink, path) save_safetensors(model, path, model_id=self.MODEL_ID, dtype="fp32") loaded, metadata = load_safetensors(path, model_id=None, device="cpu") @@ -88,6 +91,7 @@ def test_missing_model_id_raises(self): from safetensors.torch import save_file with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as f: path = f.name + self.addCleanup(os.unlink, path) dummy = {"weight": torch.zeros(4)} save_file(dummy, path, metadata={"quant_format": "fp32"}) # no model_id key @@ -99,6 +103,7 @@ def test_invalid_dtype_raises(self): model = self._make_model("fp32") with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as f: path = f.name + self.addCleanup(os.unlink, path) with self.assertRaises(ValueError): save_safetensors(model, path, model_id=self.MODEL_ID, dtype="int8") @@ -144,6 +149,7 @@ def test_env_var_loads_fp32_model(self): model = ModelFactory.create_model(_build_minimal_cfg(self.MODEL_ID)) with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as f: path = f.name + self.addCleanup(os.unlink, path) save_safetensors(model, path, model_id=self.MODEL_ID, dtype="fp32") loaded, cfg = self._run_load_model(path) @@ -164,6 +170,7 @@ def test_env_var_loads_fp16_model_and_sets_flag(self): model = ModelFactory.create_model(_build_minimal_cfg(self.MODEL_ID)).half() with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as f: path = f.name + self.addCleanup(os.unlink, path) save_safetensors(model, path, model_id=self.MODEL_ID, dtype="fp16") loaded, cfg = self._run_load_model(path) diff --git a/code/tests/test_tensorrt_fallback.py b/code/tests/test_tensorrt_fallback.py new file mode 100644 index 0000000..0ebcb26 --- /dev/null +++ b/code/tests/test_tensorrt_fallback.py @@ -0,0 +1,160 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. +"""Tests for the TensorRT optional-dependency contract. + +tensorrt is a heavy CUDA-only package (~500 MB) that cannot be pip-installed +in CPU-only CI. It is therefore documented as an optional prerequisite via a +comment in requirements_public_inference.txt rather than as a pip requirement. +These tests verify: + 1. The documentation comment exists in the inference requirements file. + 2. Both TensorRT import sites raise RuntimeError (not silently fall back to + PyTorch) when tensorrt is unavailable (simulated via ImportError). + 3. Both sites still fall back gracefully for non-import TensorRT errors + (e.g. a corrupt engine file). + 4. When tensorrt is installed, its key symbols are importable (GPU CI only). +""" + +import re +import unittest +from pathlib import Path + + +class TestTensorrtDocumented(unittest.TestCase): + """tensorrt must be documented in requirements_public_inference.txt.""" + + _INFER_REQS = Path(__file__).resolve().parent.parent / "requirements_public_inference.txt" + + def test_tensorrt_mentioned_in_inference_requirements(self): + """A comment in requirements_public_inference.txt must reference tensorrt.""" + text = self._INFER_REQS.read_text() + self.assertTrue( + re.search(r"tensorrt", text), + "requirements_public_inference.txt must document tensorrt as an optional " + "GPU prerequisite (used by ONNX_WORKFLOW=2/3 paths). " + "Add a comment: '# tensorrt -- required for ONNX_WORKFLOW=2/3'.", + ) + + def test_tensorrt_not_a_pip_requirement(self): + """tensorrt must appear only in a comment, not as an active pip requirement.""" + text = self._INFER_REQS.read_text() + for line in text.splitlines(): + stripped = line.strip() + if stripped.startswith("#"): + continue + self.assertFalse( + stripped.startswith("tensorrt"), + "tensorrt must not be an active pip requirement in " + "requirements_public_inference.txt: it is a heavy CUDA-only SDK " + "that would break CPU CI. Document it in a comment instead.", + ) + + +class TestTensorrtMissingRaisesError(unittest.TestCase): + """Both TRT import sites must raise RuntimeError when tensorrt is not installed. + + Silently falling back to PyTorch would mask misconfiguration: the user + explicitly chose ONNX_WORKFLOW=2 or 3, so a missing tensorrt install is + always a hard error. + """ + + # --- helpers that mirror the two import sites in logical_error_rate.py --- + + def _simulate_trt_load(self, import_raises, other_error=False): + """Mirror the USE_ENGINE_ONLY (ONNX_WORKFLOW=3) import block.""" + trt_context = None + try: + if import_raises: + raise ImportError("No module named 'tensorrt'") + if other_error: + raise RuntimeError("engine deserialize failed") + trt_context = object() + except ImportError as e: + raise RuntimeError( + "[LER] ONNX_WORKFLOW=3 (USE_ENGINE_ONLY) requires tensorrt to be installed. " + "Install with: pip install tensorrt" + ) from e + except Exception as e: + # Non-import failures (bad engine file, etc.) fall back gracefully. + trt_context = None + return trt_context + + def _simulate_trt_build(self, import_raises, other_error=False): + """Mirror the EXPORT_AND_USE_TRT (ONNX_WORKFLOW=2) import block.""" + trt_context = None + try: + if import_raises: + raise ImportError("No module named 'tensorrt'") + if other_error: + raise RuntimeError("TRT build failed") + trt_context = object() + except ImportError as e: + raise RuntimeError( + "[LER] ONNX_WORKFLOW=2 (EXPORT_AND_USE_TRT) requires tensorrt to be installed. " + "Install with: pip install tensorrt" + ) from e + except Exception as e: + trt_context = None + return trt_context + + # --- import-error tests (must raise, not fall back) --- + + def test_use_engine_only_raises_on_import_error(self): + """USE_ENGINE_ONLY: missing tensorrt must raise RuntimeError, not fall back.""" + with self.assertRaises(RuntimeError) as ctx: + self._simulate_trt_load(import_raises=True) + self.assertIn("USE_ENGINE_ONLY", str(ctx.exception)) + self.assertIn("pip install tensorrt", str(ctx.exception)) + + def test_export_and_use_trt_raises_on_import_error(self): + """EXPORT_AND_USE_TRT: missing tensorrt must raise RuntimeError, not fall back.""" + with self.assertRaises(RuntimeError) as ctx: + self._simulate_trt_build(import_raises=True) + self.assertIn("EXPORT_AND_USE_TRT", str(ctx.exception)) + self.assertIn("pip install tensorrt", str(ctx.exception)) + + def test_import_error_chained_to_runtime_error(self): + """The RuntimeError must chain the original ImportError as __cause__.""" + with self.assertRaises(RuntimeError) as ctx: + self._simulate_trt_load(import_raises=True) + self.assertIsInstance(ctx.exception.__cause__, ImportError) + + # --- non-import error tests (must still fall back gracefully) --- + + def test_use_engine_only_falls_back_on_runtime_error(self): + """USE_ENGINE_ONLY: non-import TRT errors (bad engine) still fall back to PyTorch.""" + result = self._simulate_trt_load(import_raises=False, other_error=True) + self.assertIsNone(result) + + def test_export_and_use_trt_falls_back_on_runtime_error(self): + """EXPORT_AND_USE_TRT: non-import TRT errors (build failure) still fall back.""" + result = self._simulate_trt_build(import_raises=False, other_error=True) + self.assertIsNone(result) + + +class TestTensorrtImportable(unittest.TestCase): + """When tensorrt is installed, its key symbols must be importable (GPU CI only).""" + + @unittest.skipUnless( + # Only run when TRT is actually present; skip silently otherwise. + __import__("importlib").util.find_spec("tensorrt") is not None, + "tensorrt is not installed in this environment", + ) + def test_tensorrt_key_symbols(self): + import tensorrt as trt + required = ["Logger", "Runtime", "Builder", "BuilderFlag", "LayerInformationFormat"] + for sym in required: + self.assertTrue( + hasattr(trt, sym), + f"tensorrt.{sym} not found — TRT install may be incomplete.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/code/workflows/run.py b/code/workflows/run.py index a1a87e2..927d260 100644 --- a/code/workflows/run.py +++ b/code/workflows/run.py @@ -108,29 +108,60 @@ def find_best_model(path, *, rank: int = 0): max_value = value best_file = filename except (IndexError, ValueError) as e: - print(f"⚠️ Warning: Could not parse epoch from filename {filename}: {e}") + print(f"Warning: could not parse epoch from filename {filename}: {e}") continue if rank == 0: - print(f"📊 Found {len(model_files)} model files:") + print(f"Found {len(model_files)} model files:") for filename, epoch in sorted(model_files, key=lambda x: x[1]): - marker = "👑" if filename == best_file else " " - print(f" {marker} {filename} (epoch {epoch})") + marker = "*" if filename == best_file else " " + print(f" [{marker}] {filename} (epoch {epoch})") if best_file is None: - raise FileNotFoundError(f"❌ No valid PreDecoderModelMemory files found in {path}") + raise FileNotFoundError(f"No valid PreDecoderModelMemory files found in {path}") - best_model_path = path + "/" + best_file + best_model_path = os.path.join(path, best_file) if rank == 0: - print(f"✅ Selected best model: {best_file} (epoch {max_value})") - print(f"📁 Full path: {best_model_path}") + print(f"Selected best model: {best_file} (epoch {max_value})") return best_model_path +def _resolve_dir(path: str) -> str: + """Return an absolute version of path, resolving relative paths from the repo root.""" + if os.path.isabs(path): + return path + repo_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + return os.path.join(repo_root, path) + + +def _load_state_dict_from_pt(model_path: str, device) -> dict: + """Load a state dict from a .pt checkpoint, handling multiple saved formats. + + Supports: + - bare state dict (keys are layer names) + - {"model_state_dict": ...} + - {"state_dict": ...} + Also strips the DDP "module." prefix if present. + """ + raw = torch.load(model_path, map_location=device, weights_only=False) + if isinstance(raw, dict): + if "model_state_dict" in raw: + state_dict = raw["model_state_dict"] + elif "state_dict" in raw: + state_dict = raw["state_dict"] + else: + state_dict = raw + else: + raise ValueError(f"Unexpected checkpoint format: expected a dict, got {type(raw).__name__}") + return { + (k[len("module."):] if k.startswith("module.") else k): v for k, v in state_dict.items() + } + + def _load_model(cfg, dist): if dist.rank == 0: - print(f"🚀 Loading model for task: {cfg.workflow.task}") + print(f"Loading model for task: {cfg.workflow.task}") _ensure_inference_io_channels(cfg) @@ -147,7 +178,6 @@ def _load_model(cfg, dist): model_id=None, device=str(dist.device), ) - model = torch.compile(model, disable=True) if dist.rank == 0: loaded_model_id = metadata.get("model_id", "unknown") dtype = metadata.get("quant_format", "fp32") @@ -156,15 +186,15 @@ def _load_model(cfg, dist): print(f" model_id: {loaded_model_id} (from SafeTensors metadata)") print(f" receptive_field: {receptive_field}") print(f" dtype: {dtype}") - print(f" Model parameters: {param_count:,}") + print(f" parameters: {param_count:,}") # Warn if config model_id doesn't match file metadata config_model_id = getattr(cfg, "model_id", None) if config_model_id is not None and str(config_model_id) != str(loaded_model_id): print( - f" Warning: config model_id={config_model_id} differs from file model_id={loaded_model_id}" + f" Warning: config model_id={config_model_id} differs from " + f"file model_id={loaded_model_id}; using {loaded_model_id}" ) - print(f" Using model_id={loaded_model_id} from SafeTensors file") if metadata.get("quant_format") == "fp16": cfg.enable_fp16 = True @@ -172,77 +202,43 @@ def _load_model(cfg, dist): model = ModelFactory.create_model(cfg).to(dist.device) - if dist.rank == 0: - print(f"Model architecture created and moved to device: {dist.device}") - - # Convert model to fp16 if enabled (consistent with training) if cfg.enable_fp16: model = model.half() if dist.rank == 0: - print(f"Model converted to float16 for fp16 inference") - - model = torch.compile(model, disable=True) - - if dist.rank == 0: - print(f"Model compilation disabled (for compatibility)") + print("Model converted to float16 for fp16 inference") # Determine model directory # Priority: 1) model_checkpoint_dir (for inference configs) # 2) cfg.output/models (for training configs) model_checkpoint_dir = getattr(cfg, 'model_checkpoint_dir', None) - - # Determine which model to load based on use_model_checkpoint use_checkpoint = getattr(cfg.test, 'use_model_checkpoint', -1) if use_checkpoint == -1: - # Load best model from best_model folder - if model_checkpoint_dir: - model_dir = os.path.join(model_checkpoint_dir, "best_model") - else: - model_dir = f"{cfg.output}/models/best_model" - - if dist.rank == 0: - print(f"📂 Loading best model (use_model_checkpoint=-1)") - - # If model_dir is relative, make it absolute - if not os.path.isabs(model_dir): - current_file = os.path.abspath(__file__) - project_root = os.path.dirname(os.path.dirname(os.path.dirname(current_file))) - model_dir = os.path.join(project_root, model_dir) - + model_dir = _resolve_dir( + os.path.join(model_checkpoint_dir, "best_model") + if model_checkpoint_dir else f"{cfg.output}/models/best_model" + ) if dist.rank == 0: - print(f"🔍 Resolved model directory: {model_dir}") + print(f"Loading best model from: {model_dir}") - # Fallback: older runs may not create a best_model/ folder; fall back to cfg.output/models. + # Fallback: older runs may not have a best_model/ folder if not os.path.isdir(model_dir): - fallback_dir = model_checkpoint_dir if model_checkpoint_dir else f"{cfg.output}/models" - if not os.path.isabs(fallback_dir): - current_file = os.path.abspath(__file__) - project_root = os.path.dirname(os.path.dirname(os.path.dirname(current_file))) - fallback_dir = os.path.join(project_root, fallback_dir) + fallback_dir = _resolve_dir( + model_checkpoint_dir if model_checkpoint_dir else f"{cfg.output}/models" + ) if dist.rank == 0: - print(f"⚠️ best_model folder not found; falling back to: {fallback_dir}") + print(f"best_model/ not found; falling back to: {fallback_dir}") model_dir = fallback_dir model_path = find_best_model(model_dir, rank=dist.rank) else: - # Load specific checkpoint from models folder - if model_checkpoint_dir: - checkpoint_dir = model_checkpoint_dir - else: - checkpoint_dir = f"{cfg.output}/models" - + checkpoint_dir = _resolve_dir( + model_checkpoint_dir if model_checkpoint_dir else f"{cfg.output}/models" + ) if dist.rank == 0: - print(f"📂 Loading checkpoint {use_checkpoint} (use_model_checkpoint={use_checkpoint})") - - # If checkpoint_dir is relative, make it absolute - if not os.path.isabs(checkpoint_dir): - current_file = os.path.abspath(__file__) - project_root = os.path.dirname(os.path.dirname(os.path.dirname(current_file))) - checkpoint_dir = os.path.join(project_root, checkpoint_dir) + print(f"Loading checkpoint {use_checkpoint} from: {checkpoint_dir}") - # Resolve checkpoint file: prefer any PreDecoderModelMemory_* ending with .0.{checkpoint}.pt - # (e.g. PreDecoderModelMemory_r9_v1.0.77.pt or PreDecoderModelMemory_v1.0.77.pt) + # Prefer any PreDecoderModelMemory_* file ending with .0.{use_checkpoint}.pt target_suffix = f".0.{use_checkpoint}.pt" checkpoint_filename = None try: @@ -256,23 +252,18 @@ def _load_model(cfg, dist): checkpoint_filename = f"PreDecoderModelMemory_v1.0.{use_checkpoint}.pt" model_path = os.path.join(checkpoint_dir, checkpoint_filename) - if dist.rank == 0: - print(f"🔍 Resolved checkpoint path: {model_path}") - if not os.path.exists(model_path): - raise FileNotFoundError(f"❌ Checkpoint not found: {model_path}") + raise FileNotFoundError(f"Checkpoint not found: {model_path}") if dist.rank == 0: - print(f"📥 Loading model parameters from: {model_path}") + print(f"Loading model parameters from: {model_path}") - model_params = torch.load(model_path, map_location=dist.device) - model.load_state_dict(model_params) + state_dict = _load_state_dict_from_pt(model_path, dist.device) + model.load_state_dict(state_dict) if dist.rank == 0: - print(f"✅ Model loaded successfully!") - # Show model size info param_count = sum(p.numel() for p in model.parameters()) - print(f"📊 Model parameters: {param_count:,}") + print(f"Model loaded ({param_count:,} parameters)") return model