diff --git a/README.md b/README.md index ad40921..55d982e 100644 --- a/README.md +++ b/README.md @@ -173,6 +173,62 @@ Notes: - TensorRT workflows (`ONNX_WORKFLOW=2` or `3`) require `tensorrt` and `modelopt`. - FP8 quantization failure is fatal. INT8 failure falls back to the FP32 ONNX model silently. - ONNX and engine files are written to the current working directory. +- `ONNX_WORKFLOW` is also honoured by the `decoder_ablation` workflow — see below. + +### Decoder ablation study with cudaq-qec (optional) + +The `decoder_ablation` workflow compares multiple global decoders on the residual syndromes left +by the neural pre-decoder. It supports both PyTorch and TensorRT backends for the pre-decoder +and GPU-accelerated global decoders from the `cudaq-qec` package (`cudaq_qec`). + +**PyTorch pre-decoder + cudaq-qec global decoders:** + +```bash +# Requires: cudaq-qec (cudaq_qec), ldpc, beliefmatching, scipy +WORKFLOW=decoder_ablation bash code/scripts/local_run.sh +``` + +**TRT pre-decoder + cudaq-qec global decoders (full GPU pipeline):** + +The same `ONNX_WORKFLOW` variable used for `inference` also applies here. When a TRT engine is +active, the neural pre-decoder runs via TensorRT (fast, quantised inference) while `cudaq-qec` +decoders handle the residual syndromes on GPU — combining fast TRT inference with +GPU-accelerated global decoding end-to-end. + +```bash +# Export ONNX, build TRT engine, run ablation (TRT pre-decoder + cudaq-qec) +ONNX_WORKFLOW=2 WORKFLOW=decoder_ablation bash code/scripts/local_run.sh + +# INT8 quantized TRT pre-decoder + cudaq-qec +ONNX_WORKFLOW=2 QUANT_FORMAT=int8 WORKFLOW=decoder_ablation bash code/scripts/local_run.sh + +# Load a previously built engine, then run ablation +ONNX_WORKFLOW=3 WORKFLOW=decoder_ablation bash code/scripts/local_run.sh +``` + +The ablation study reports per-decoder logical error rates, convergence statistics for +`cudaq-qec` BP variants, residual syndrome weight distributions, and timing breakdowns. +Results are written to `outputs//plots/`. + +**Decoder variants benchmarked:** + +| Decoder | Source | Notes | +|---|---|---| +| No-op | — | Pre-decoder output only, no global correction | +| Union-Find | `ldpc` | Fast, sub-optimal | +| BP-only | `ldpc` | Belief propagation, no OSD | +| BP+LSD-0 | `ldpc` | BP with localized statistics decoding | +| Uncorr-PM | PyMatching | Uncorrelated minimum-weight perfect matching | +| Corr-PM | PyMatching | Correlated MWPM (best classical baseline) | +| cudaq-BP | `cudaq-qec` | Sum-product BP on GPU | +| cudaq-MinSum | `cudaq-qec` | Min-sum BP on GPU | +| cudaq-BP+OSD-0/7 | `cudaq-qec` | BP + ordered statistics decoding | +| cudaq-MemBP | `cudaq-qec` | Memory-based min-sum BP | +| cudaq-MemBP+OSD | `cudaq-qec` | Memory BP + OSD | +| cudaq-RelayBP | `cudaq-qec` | Sequential relay composition | + +`cudaq-qec` decoders are loaded automatically when `cudaq_qec` is importable; the study +degrades gracefully to the non-cudaq decoders if the package is absent. ### GPU selection diff --git a/code/evaluation/failure_analysis.py b/code/evaluation/failure_analysis.py index e5ef945..4f9a4b1 100644 --- a/code/evaluation/failure_analysis.py +++ b/code/evaluation/failure_analysis.py @@ -23,8 +23,11 @@ import torch from evaluation.logical_error_rate import ( + OnnxWorkflow, + PreDecoderMemoryEvalModule, _build_stab_maps, _decode_batch, + _parse_quant_format, map_grid_to_stabilizer_tensor, sample_predictions, ) @@ -676,6 +679,191 @@ def _print_ablation_results( _plot_conditional_ler(weight_bucket_stats, decoder_names, basis, cfg) +def _setup_trt_for_ablation(model, cfg, dist, device, basis, D, half, stim_dets): + """ + Parse ONNX_WORKFLOW and, when requested, build or load a TensorRT engine + for the pre-decoder. + + Returns ``(trt_context, onnx_workflow)`` where *trt_context* is either a + ``(execution_context, engine)`` pair or ``None`` (PyTorch fallback). + """ + trt_context = None + onnx_workflow = OnnxWorkflow.TORCH_ONLY + try: + onnx_workflow = OnnxWorkflow(int(os.environ.get("ONNX_WORKFLOW", "0").strip())) + except ValueError: + pass + + if onnx_workflow == OnnxWorkflow.TORCH_ONLY: + return trt_context, onnx_workflow + + code_rotation = getattr(cfg.data, "code_rotation", "XV") + maps_dict = _build_stab_maps(D, code_rotation) + pipeline_module = PreDecoderMemoryEvalModule(model, cfg, maps_dict, device).to(device) + pipeline_module.eval() + + quant_format = _parse_quant_format(rank=dist.rank) + quant_suffix = f"_{quant_format}" if quant_format else "" + T_test = int(getattr(cfg.test, "n_rounds", cfg.n_rounds)) + onnx_path = os.path.join( + os.getcwd(), f"predecoder_memory_d{D}_T{T_test}_{basis}{quant_suffix}.onnx" + ) + engine_path = onnx_path.replace(".onnx", ".engine") + batch_size_onnx = int(getattr(cfg.test.dataloader, "batch_size", 2048)) + + if onnx_workflow == OnnxWorkflow.USE_ENGINE_ONLY and device.type == "cuda": + if os.path.isfile(engine_path): + try: + import tensorrt as trt + logger = trt.Logger(trt.Logger.WARNING) + runtime = trt.Runtime(logger) + t0 = _time.perf_counter() + with open(engine_path, "rb") as _f: + serialized = _f.read() + engine = runtime.deserialize_cuda_engine(serialized) + if engine is None: + raise RuntimeError("TensorRT engine deserialize failed") + trt_context = (engine.create_execution_context(), engine) + if dist.rank == 0: + print( + f"[Ablation] TensorRT engine loaded from {engine_path}" + f" in {_time.perf_counter() - t0:.2f}s" + ) + except Exception as e: + if dist.rank == 0: + print(f"[Ablation] TRT load failed: {e}; using PyTorch.") + else: + if dist.rank == 0: + print( + f"[Ablation] ONNX_WORKFLOW=3 but engine not found: {engine_path};" + " using PyTorch." + ) + + elif onnx_workflow in (OnnxWorkflow.EXPORT_ONNX_ONLY, OnnxWorkflow.EXPORT_AND_USE_TRT): + if dist.rank == 0: + try: + fp32_onnx_path = ( + onnx_path + if not quant_format else onnx_path.replace(f"_{quant_format}.onnx", ".onnx") + ) + # stim_dets shape is (N, num_detectors) = (N, (2*T+1)*half) — use it as sample input. + example_dets = torch.from_numpy(stim_dets[:batch_size_onnx] + ).to(device=device, dtype=torch.uint8) + torch.onnx.export( + pipeline_module, + example_dets, + fp32_onnx_path, + opset_version=18, + external_data=False, + input_names=["dets"], + output_names=["L_and_residual_dets"], + dynamic_axes={ + "dets": { + 0: "batch" + }, + "L_and_residual_dets": { + 0: "batch" + } + }, + do_constant_folding=True, + dynamo=False, + ) + print(f"[Ablation] Exported FP32 ONNX: {fp32_onnx_path}") + + if quant_format: + calib_samples = int(os.environ.get("QUANT_CALIB_SAMPLES", "256")) + calib_dets = stim_dets[:calib_samples].astype(np.uint8) + try: + import modelopt.onnx.quantization as mq + quant_kwargs = {} + if quant_format == "fp8": + quant_kwargs["op_types_to_quantize"] = ["Conv"] + quant_kwargs["high_precision_dtype"] = "fp16" + mq.quantize( + onnx_path=fp32_onnx_path, + quantize_mode=quant_format, + calibration_data={"dets": calib_dets.astype("float32")}, + output_path=onnx_path, + **quant_kwargs, + ) + except ImportError: + if quant_format == "fp8": + raise RuntimeError( + "[Ablation] FP8 quantization requires nvidia-modelopt." + ) + from evaluation.logical_error_rate import _ort_quantize_int8 + _ort_quantize_int8(fp32_onnx_path, onnx_path, calib_dets) + print(f"[Ablation] Exported quantized ONNX: {onnx_path}") + except Exception as e: + print(f"[Ablation] ONNX export failed: {e}; using PyTorch.") + onnx_workflow = OnnxWorkflow.TORCH_ONLY + + if dist.world_size > 1: + # Broadcast rank 0's onnx_workflow (may have been set to TORCH_ONLY on + # export failure) so non-zero ranks skip the TRT build when rank 0 failed. + wf_list = [onnx_workflow] + torch.distributed.broadcast_object_list(wf_list, src=0) + onnx_workflow = wf_list[0] + engine_path = onnx_path.replace(".onnx", ".engine") + + if onnx_workflow == OnnxWorkflow.EXPORT_AND_USE_TRT and device.type == "cuda": + try: + import tensorrt as trt + logger = trt.Logger(trt.Logger.WARNING) + runtime = trt.Runtime(logger) + builder = trt.Builder(logger) + net_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + if quant_format in ("fp8", "int8"): + net_flags |= 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED) + network = builder.create_network(net_flags) + parser = trt.OnnxParser(network, logger) + _onnx_to_parse = ( + onnx_path if os.path.isfile(onnx_path) else + onnx_path.replace(f"_{quant_format}.onnx", ".onnx") + ) + with open(_onnx_to_parse, "rb") as _f: + if not parser.parse(_f.read()): + raise RuntimeError("TensorRT ONNX parse failed") + config = builder.create_builder_config() + if not quant_format: + config.set_flag(trt.BuilderFlag.FP16) + in_cols_trt = 2 * T_test * half + profile = builder.create_optimization_profile() + profile.set_shape( + "dets", + (1, in_cols_trt), + (batch_size_onnx, in_cols_trt), + (batch_size_onnx, in_cols_trt), + ) + config.add_optimization_profile(profile) + t0_build = _time.perf_counter() + serialized = builder.build_serialized_network(network, config) + if serialized is None: + raise RuntimeError("TensorRT build failed") + if dist.rank == 0: + print( + f"[Ablation] TRT engine built in" + f" {_time.perf_counter() - t0_build:.1f}s" + ) + engine = runtime.deserialize_cuda_engine(serialized) + if dist.rank == 0: + with open(engine_path, "wb") as _f: + _f.write(engine.serialize()) + print(f"[Ablation] TRT engine saved to {engine_path}") + trt_context = (engine.create_execution_context(), engine) + except ImportError as e: + raise RuntimeError( + "[Ablation] EXPORT_AND_USE_TRT requires tensorrt." + " Install with: pip install tensorrt" + ) from e + except Exception as e: + if dist.rank == 0: + print(f"[Ablation] TRT build failed: {e}; using PyTorch.") + trt_context = None + + return trt_context, onnx_workflow + + @torch.inference_mode() def decoder_ablation_study(model, device, dist, cfg): """ @@ -760,6 +948,18 @@ def decoder_ablation_study(model, device, dist, cfg): Hx_idx, Hz_idx, Hx_mask, Hz_mask, stab_indices_x, stab_indices_z, Kx, Kz, Lx, Lz = \ _build_logical_operators(D, code_rotation, device) D2 = D * D + half = (D * D - 1) // 2 + + # --- TRT/ONNX setup --- + # Honours the same ONNX_WORKFLOW env-var as the inference workflow: + # 0 = PyTorch only 1 = export ONNX (then use PyTorch) + # 2 = export ONNX + build TRT engine 3 = load pre-built engine + # When a TRT engine is active the pre-decoder runs at TRT speed while + # cudaq-qec decoders handle the residual syndromes on GPU — combining + # fast TRT inference with GPU-accelerated global decoding end-to-end. + trt_context, onnx_workflow = _setup_trt_for_ablation( + model, cfg, dist, device, basis, D, half, stim_dets + ) if dist.rank == 0: print( @@ -779,6 +979,11 @@ def decoder_ablation_study(model, device, dist, cfg): f"[Decoder Ablation] Decoders: No-op, Union-Find, BP+LSD-0," f" Uncorr PM, Corr PM, {cudaq_names_str}, + Baseline PM" ) + _backend = ( + f"TRT (ONNX_WORKFLOW={onnx_workflow.value})" + if trt_context is not None else f"PyTorch (ONNX_WORKFLOW={onnx_workflow.value})" + ) + print(f"[Decoder Ablation] Pre-decoder backend: {_backend}") # --- Batch loop --- batch_size = int(getattr(cfg.test.dataloader, "batch_size", 2048)) @@ -817,27 +1022,44 @@ def decoder_ablation_study(model, device, dist, cfg): } for cn in cudaq_decoder_names } + # Cache the fixed output column count (1 + num_detectors) so we avoid a + # per-batch engine query inside the hot loop. The batch dimension is + # dynamic; index [1] always returns the constant column count. + _trt_out_ncols = None + if trt_context is not None: + _trt_ctx_pre, _ = trt_context + _trt_out_ncols = int(_trt_ctx_pre.get_tensor_shape("L_and_residual_dets")[1]) + for batch_idx in range(num_batches): start = batch_idx * batch_size end = min(start + batch_size, N) B = end - start + # Baseline detectors/obs are needed for both TRT and PyTorch paths. + baseline_detectors_batch = stim_dets[start:end] + gt_obs_batch = stim_obs[start:end] + _t0 = _time.perf_counter() - items = [test_dataset[i] for i in range(start, end)] - x_syn_diff = torch.stack([it["x_syn_diff"] for it in items] - ).to(device=device, dtype=torch.int32) - z_syn_diff = torch.stack([it["z_syn_diff"] for it in items] - ).to(device=device, dtype=torch.int32) - trainX = torch.stack([it["trainX"] for it in items]).to(device=device) + if trt_context is None: + # PyTorch path: need preprocessed grid tensors from dataset items. + items = [test_dataset[i] for i in range(start, end)] + x_syn_diff = torch.stack([it["x_syn_diff"] for it in items] + ).to(device=device, dtype=torch.int32) + z_syn_diff = torch.stack([it["z_syn_diff"] for it in items] + ).to(device=device, dtype=torch.int32) + trainX = torch.stack([it["trainX"] for it in items]).to(device=device) _timing["collate"] += _time.perf_counter() - _t0 - _, _, T = x_syn_diff.shape + if trt_context is not None: + # T derived from flat dets width: shape is (B, (2*T+1)*half) incl. boundary detectors. + T = baseline_detectors_batch.shape[1] // (2 * half) + else: + _, _, T = x_syn_diff.shape if T < 2: continue - # Baseline: raw Stim syndromes + ground truth - baseline_detectors_batch = stim_dets[start:end] - gt_obs_batch = stim_obs[start:end] + # Weight accumulation must happen after the T < 2 guard so that skipped + # batches do not inflate baseline weight counts. all_baseline_weights.extend(baseline_detectors_batch.sum(axis=1).tolist()) _t0 = _time.perf_counter() @@ -848,41 +1070,68 @@ def decoder_ablation_study(model, device, dist, cfg): gt_obs_np = gt_obs_batch.reshape(-1).astype(np.int64) - # Pre-decoder forward pass + residual syndrome construction + # Pre-decoder forward pass + residual syndrome construction. + # TRT path: feed raw dets directly to the TRT engine, which runs the full + # PreDecoderMemoryEvalModule pipeline (preprocessing → Conv3D → residual + # assembly) in a single optimised kernel graph. The output L_and_residual_dets + # has the same layout as the PyTorch path: col 0 = pre_L, cols 1: = residual + # dets ready for cudaq-qec and other global decoders. _t0 = _time.perf_counter() - residual_np, pre_L_np = _model_forward_and_residual( - model, - trainX, - x_syn_diff, - z_syn_diff, - basis, - B, - D2, - T, - Hx_idx, - Hz_idx, - Hx_mask, - Hz_mask, - Kx, - Kz, - stab_indices_x, - stab_indices_z, - Lx, - Lz, - th_data, - th_syn, - sampling_mode, - temperature_data, - temperature_syn, - cfg, - device, - num_boundary_dets, - baseline_detectors_batch, - det_model, - ) - if device.type == "cuda": - torch.cuda.synchronize() - _timing["residual_build"] += _time.perf_counter() - _t0 + if trt_context is not None: + # Pinned-memory transfer avoids an intermediate CPU allocation and + # lets the H2D copy overlap with CPU work (non_blocking=True). + dets_batch = torch.as_tensor(baseline_detectors_batch, dtype=torch.uint8 + ).pin_memory().to(device, non_blocking=True) + context, _engine = trt_context + context.set_input_shape("dets", dets_batch.shape) + L_and_residual_out = torch.empty((B, _trt_out_ncols), device=device, dtype=torch.uint8) + # Note: execute_v2 (binding-list API) is deprecated in TRT >= 10; + # migrate to set_tensor_address + execute_async_v3 when upgrading. + context.execute_v2( + bindings=[int(dets_batch.data_ptr()), + int(L_and_residual_out.data_ptr())] + ) + if device.type == "cuda": + torch.cuda.synchronize() + # Single D2H transfer then slice on CPU to avoid two round trips. + out_cpu = L_and_residual_out.cpu().numpy() + pre_L_np = out_cpu[:, 0].astype(np.int64) + residual_np = out_cpu[:, 1:] + _timing["model_fwd"] += _time.perf_counter() - _t0 + else: + residual_np, pre_L_np = _model_forward_and_residual( + model, + trainX, + x_syn_diff, + z_syn_diff, + basis, + B, + D2, + T, + Hx_idx, + Hz_idx, + Hx_mask, + Hz_mask, + Kx, + Kz, + stab_indices_x, + stab_indices_z, + Lx, + Lz, + th_data, + th_syn, + sampling_mode, + temperature_data, + temperature_syn, + cfg, + device, + num_boundary_dets, + baseline_detectors_batch, + det_model, + ) + if device.type == "cuda": + torch.cuda.synchronize() + _timing["residual_build"] += _time.perf_counter() - _t0 weights = residual_np.sum(axis=1) all_residual_weights.extend(weights.tolist()) diff --git a/code/scripts/local_run.sh b/code/scripts/local_run.sh index e072439..ba117a4 100644 --- a/code/scripts/local_run.sh +++ b/code/scripts/local_run.sh @@ -24,6 +24,22 @@ set -euo pipefail # GPUS=4 bash code/scripts/local_run.sh # CUDA_VISIBLE_DEVICES=1 bash code/scripts/local_run.sh # use only GPU 1 # +# ONNX / TRT fast inference (requires tensorrt; set ONNX_WORKFLOW before running): +# ONNX_WORKFLOW=1 WORKFLOW=inference bash code/scripts/local_run.sh # export ONNX only (inspect/reuse later) +# ONNX_WORKFLOW=2 WORKFLOW=inference bash code/scripts/local_run.sh # export ONNX + build TRT + run TRT inference +# ONNX_WORKFLOW=2 QUANT_FORMAT=int8 WORKFLOW=inference bash code/scripts/local_run.sh # INT8 quantized TRT +# ONNX_WORKFLOW=2 QUANT_FORMAT=fp8 WORKFLOW=inference bash code/scripts/local_run.sh # FP8 quantized TRT (requires nvidia-modelopt) +# ONNX_WORKFLOW=3 WORKFLOW=inference bash code/scripts/local_run.sh # load pre-built engine, skip export +# +# Decoder ablation study with cudaq-qec global decoders (requires cudaq-qec): +# WORKFLOW=decoder_ablation bash code/scripts/local_run.sh +# +# Decoder ablation with TRT pre-decoder + cudaq-qec global decoders +# (combines fast TRT inference for the neural pre-decoder with GPU-accelerated +# cudaq-qec decoders for the residual syndromes — full GPU pipeline end-to-end): +# ONNX_WORKFLOW=2 WORKFLOW=decoder_ablation bash code/scripts/local_run.sh # export+build TRT, then ablation +# ONNX_WORKFLOW=3 WORKFLOW=decoder_ablation bash code/scripts/local_run.sh # load existing engine, then ablation +# # Notes: # - Public config is `conf/config_public.yaml`. Users should edit only that file. # - Training knobs are auto-managed in code (epochs, shots/epoch, batch schedule, etc.). diff --git a/code/tests/test_failure_analysis.py b/code/tests/test_failure_analysis.py index 31067e6..36ff37c 100644 --- a/code/tests/test_failure_analysis.py +++ b/code/tests/test_failure_analysis.py @@ -848,5 +848,518 @@ def test_cudaq_error_counts_are_non_negative(self): self.assertLessEqual(result["decoder_errors"]["cudaq-BP"], result["total_samples"]) +class _MockCUDADevice: + """ + CPU-compatible mock device that reports type='cuda' so the TRT guard + (``if device.type == "cuda"``) is exercised without a physical GPU. + + torch.device cannot be subclassed, so this is a plain Python object. + All torch factory functions and nn.Module.to() that receive this device + must be patched (see _patch_tensor_to_for_mock_cuda) to redirect to the + real CPU device before reaching PyTorch's C layer. + """ + type = "cuda" + index = 0 + + def __str__(self): + return "cuda:0" + + def __repr__(self): + return "device(type='cuda', index=0)" + + +def _make_mock_trt_module(num_detectors): + """ + Build a minimal tensorrt mock whose execution context produces an all-zero + L_and_residual_dets tensor of shape (B, 1 + num_detectors). + + execute_v2 is a no-op; the pre-allocated output tensor stays at zero, which + is a valid (all-correct pre-decoder) output for testing purposes. + """ + B_holder = [None] + + class _Ctx: + + def set_input_shape(self, name, shape): + B_holder[0] = shape[0] + + def get_tensor_shape(self, name): + return (B_holder[0], 1 + num_detectors) + + def execute_v2(self, bindings): + pass # output tensor remains zeroed — valid binary values + + @property + def _engine(self): + return None # not accessed in ablation path + + class _Engine: + + def create_execution_context(self): + return _Ctx() + + def serialize(self): + return b"" + + class _Runtime: + + def deserialize_cuda_engine(self, data): + return _Engine() + + class _Logger: + WARNING = 1 + + class _BuilderFlag: + FP16 = 0 + + class _NetworkDefinitionCreationFlag: + EXPLICIT_BATCH = 0 + STRONGLY_TYPED = 1 + + class _OnnxParser: + + def __init__(self, network, logger): + pass + + def parse(self, data): + return True + + class _Profile: + + def set_shape(self, name, mn, opt, mx): + pass + + class _BuilderConfig: + + def set_flag(self, flag): + pass + + def add_optimization_profile(self, profile): + pass + + class _Network: + pass + + class _Builder: + + def create_network(self, flags): + return _Network() + + def create_optimization_profile(self): + return _Profile() + + def create_builder_config(self): + return _BuilderConfig() + + def build_serialized_network(self, network, config): + return b"" + + mock_trt = types.ModuleType("tensorrt") + mock_trt.Logger = _Logger + mock_trt.Runtime = lambda logger: _Runtime() + mock_trt.Builder = lambda logger: _Builder() + mock_trt.OnnxParser = _OnnxParser + mock_trt.BuilderFlag = _BuilderFlag() + mock_trt.NetworkDefinitionCreationFlag = _NetworkDefinitionCreationFlag() + return mock_trt + + +def _redirect_mock_device(v): + """Return torch.device("cpu") when v is a _MockCUDADevice; else v unchanged.""" + return torch.device("cpu") if isinstance(v, _MockCUDADevice) else v + + +def _patch_tensor_to_for_mock_cuda(): + """ + Return a context manager that allows TRT tests to run without a physical GPU. + + torch.device cannot be subclassed, so _MockCUDADevice is a plain Python + object. PyTorch's C layer rejects it in every tensor-creation call, so we + patch all relevant entry points to redirect _MockCUDADevice -> cpu before + the C layer sees it. torch.cuda.synchronize is stubbed to a no-op. + + Functions patched: + - torch.Tensor.to (Tensor moves) + - torch.nn.Module.to (model moves) + - torch.zeros/ones/empty/arange/full/rand/randn/randint/as_tensor/tensor + - torch.cuda.synchronize (no GPU available) + """ + from contextlib import contextmanager, ExitStack + + _FACTORY_NAMES = [ + "zeros", "ones", "empty", "arange", "full", "rand", "randn", "randint", "as_tensor", + "tensor" + ] + + @contextmanager + def _ctx(): + _orig_tensor_to = torch.Tensor.to + _orig_module_to = torch.nn.Module.to + + def _patched_tensor_to(self, *args, **kwargs): + return _orig_tensor_to( + self, *[_redirect_mock_device(a) for a in args], **{ + k: _redirect_mock_device(v) for k, v in kwargs.items() + } + ) + + def _patched_module_to(self, *args, **kwargs): + return _orig_module_to( + self, *[_redirect_mock_device(a) for a in args], **{ + k: _redirect_mock_device(v) for k, v in kwargs.items() + } + ) + + def _make_factory_patch(orig): + + def _patched(*args, **kwargs): + if "device" in kwargs: + kwargs["device"] = _redirect_mock_device(kwargs["device"]) + return orig(*args, **kwargs) + + return _patched + + with ExitStack() as stack: + stack.enter_context(patch.object(torch.Tensor, "to", _patched_tensor_to)) + stack.enter_context(patch.object(torch.nn.Module, "to", _patched_module_to)) + stack.enter_context(patch("torch.cuda.synchronize")) + for _name in _FACTORY_NAMES: + stack.enter_context( + patch.object(torch, _name, _make_factory_patch(getattr(torch, _name))) + ) + yield + + return _ctx() + + +class TestOnnxWorkflowParsing(unittest.TestCase): + """OnnxWorkflow env-var is read and falls back gracefully for invalid values.""" + + def test_default_is_torch_only(self): + from evaluation.logical_error_rate import OnnxWorkflow + # When ONNX_WORKFLOW is absent the default int is 0 → TORCH_ONLY. + with patch.dict("os.environ", {}, clear=True): + val = OnnxWorkflow(0) + self.assertEqual(val, OnnxWorkflow.TORCH_ONLY) + + def test_valid_values_parse(self): + from evaluation.logical_error_rate import OnnxWorkflow + for raw, expected in ( + ("0", OnnxWorkflow.TORCH_ONLY), ("1", OnnxWorkflow.EXPORT_ONNX_ONLY), + ("2", OnnxWorkflow.EXPORT_AND_USE_TRT), ("3", OnnxWorkflow.USE_ENGINE_ONLY) + ): + self.assertEqual(OnnxWorkflow(int(raw)), expected, f"raw={raw!r}") + + def test_invalid_value_raises_valueerror(self): + from evaluation.logical_error_rate import OnnxWorkflow + with self.assertRaises(ValueError): + OnnxWorkflow(99) + + +class TestDecoderAblationStudyTRTFallback(unittest.TestCase): + """ + ONNX_WORKFLOW=3 with a missing engine file must fall back to PyTorch silently + and produce the same result structure as the default PyTorch path. + """ + + _D = 3 + _T = 3 + _N = 8 + + @classmethod + def setUpClass(cls): + cls._result = cls._run_once("X") + + @classmethod + def _run_once(cls, basis="X"): + from evaluation.failure_analysis import decoder_ablation_study + from data.datapipe_stim import QCDataPipePreDecoder_Memory_inference + real_ds = QCDataPipePreDecoder_Memory_inference( + distance=cls._D, + n_rounds=cls._T, + num_samples=cls._N, + error_mode="circuit_level_surface_custom", + p_error=0.01, + measure_basis=basis, + code_rotation="XV", + ) + with tempfile.TemporaryDirectory() as tmpdir: + cfg = _make_cfg(tmpdir, distance=cls._D, n_rounds=cls._T, basis=basis, n_samples=cls._N) + cfg.test.n_rounds = cls._T + with patch("data.factory.DatapipeFactory") as mf, \ + patch.dict("os.environ", {"ONNX_WORKFLOW": "3"}), \ + patch("os.getcwd", return_value=tmpdir): + mf.create_datapipe_inference.return_value = real_ds + result = decoder_ablation_study( + _ZeroModel(), torch.device("cpu"), _DummyDist(), cfg + ) + return result + + def test_missing_engine_does_not_crash(self): + self.assertEqual(self._result["total_samples"], self._N) + + def test_missing_engine_result_structure_intact(self): + for key in ( + "baseline_errors", "decoder_errors", "residual_weights", "weight_bucket_stats", + "agreement_count", "unavailable_decoders" + ): + self.assertIn(key, self._result) + + def test_missing_engine_decoder_errors_all_base_decoders_present(self): + from evaluation.failure_analysis import DECODER_NAMES + self.assertTrue(set(DECODER_NAMES).issubset(set(self._result["decoder_errors"].keys()))) + + def test_missing_engine_sample_count_correct(self): + self.assertEqual(len(self._result["residual_weights"]), self._N) + + +class TestDecoderAblationStudyOnnxExport(unittest.TestCase): + """ + ONNX_WORKFLOW=1 must attempt ONNX export (rank 0) then fall back to PyTorch for inference. + Results must be identical in structure to the default PyTorch path. + """ + + _D = 3 + _T = 3 + _N = 8 + + def test_workflow1_exports_and_uses_pytorch(self): + from evaluation.failure_analysis import decoder_ablation_study, DECODER_NAMES + from data.datapipe_stim import QCDataPipePreDecoder_Memory_inference + real_ds = QCDataPipePreDecoder_Memory_inference( + distance=self._D, + n_rounds=self._T, + num_samples=self._N, + error_mode="circuit_level_surface_custom", + p_error=0.01, + measure_basis="X", + code_rotation="XV", + ) + exported = [] + + def _fake_onnx_export(module, *args, **kwargs): + exported.append(kwargs.get("f") or (args[1] if len(args) > 1 else None)) + + with tempfile.TemporaryDirectory() as tmpdir: + cfg = _make_cfg( + tmpdir, distance=self._D, n_rounds=self._T, basis="X", n_samples=self._N + ) + cfg.test.n_rounds = self._T + with patch("data.factory.DatapipeFactory") as mf, \ + patch.dict("os.environ", {"ONNX_WORKFLOW": "1"}), \ + patch("torch.onnx.export", side_effect=_fake_onnx_export), \ + patch("os.getcwd", return_value=tmpdir): + mf.create_datapipe_inference.return_value = real_ds + result = decoder_ablation_study( + _ZeroModel(), torch.device("cpu"), _DummyDist(), cfg + ) + + # ONNX export was attempted + self.assertEqual(len(exported), 1, "Expected exactly one torch.onnx.export call") + # Inference fell back to PyTorch (no TRT context) — same result structure + self.assertEqual(result["total_samples"], self._N) + self.assertTrue(set(DECODER_NAMES).issubset(set(result["decoder_errors"].keys()))) + + def test_workflow1_export_failure_falls_back_gracefully(self): + """If ONNX export raises, results must still be valid (PyTorch fallback).""" + from evaluation.failure_analysis import decoder_ablation_study, DECODER_NAMES + from data.datapipe_stim import QCDataPipePreDecoder_Memory_inference + real_ds = QCDataPipePreDecoder_Memory_inference( + distance=self._D, + n_rounds=self._T, + num_samples=self._N, + error_mode="circuit_level_surface_custom", + p_error=0.01, + measure_basis="X", + code_rotation="XV", + ) + with tempfile.TemporaryDirectory() as tmpdir: + cfg = _make_cfg( + tmpdir, distance=self._D, n_rounds=self._T, basis="X", n_samples=self._N + ) + cfg.test.n_rounds = self._T + with patch("data.factory.DatapipeFactory") as mf, \ + patch.dict("os.environ", {"ONNX_WORKFLOW": "1"}), \ + patch("torch.onnx.export", side_effect=RuntimeError("export broken")), \ + patch("os.getcwd", return_value=tmpdir): + mf.create_datapipe_inference.return_value = real_ds + result = decoder_ablation_study( + _ZeroModel(), torch.device("cpu"), _DummyDist(), cfg + ) + self.assertEqual(result["total_samples"], self._N) + self.assertTrue(set(DECODER_NAMES).issubset(set(result["decoder_errors"].keys()))) + + +class TestDecoderAblationStudyTRTExecution(unittest.TestCase): + """ + Full mock TRT execution path: inject a mock tensorrt module and a CPU-compatible + mock CUDA device so the TRT code path runs end-to-end without a physical GPU. + + Verifies that: + - trt_context is activated when ONNX_WORKFLOW=3 and the engine file exists + - L_and_residual_dets from the TRT context is parsed into pre_L and residual + - The rest of the batch loop (global decoders, stats) runs identically to PyTorch path + - Results have the correct structure and sample count + """ + + _D = 3 + _T = 3 + _N = 8 + + @classmethod + def setUpClass(cls): + from evaluation.failure_analysis import DECODER_NAMES + cls._decoder_names = DECODER_NAMES + cls._result_x = cls._run_once("X") + cls._result_z = cls._run_once("Z") + + @classmethod + def _run_once(cls, basis="X"): + from evaluation.failure_analysis import decoder_ablation_study + from data.datapipe_stim import QCDataPipePreDecoder_Memory_inference + real_ds = QCDataPipePreDecoder_Memory_inference( + distance=cls._D, + n_rounds=cls._T, + num_samples=cls._N, + error_mode="circuit_level_surface_custom", + p_error=0.01, + measure_basis=basis, + code_rotation="XV", + ) + circuit = real_ds.circ.stim_circuit + det_model = circuit.detector_error_model( + decompose_errors=True, approximate_disjoint_errors=True + ) + num_detectors = det_model.num_detectors + mock_trt = _make_mock_trt_module(num_detectors) + mock_device = _MockCUDADevice() + + with tempfile.TemporaryDirectory() as tmpdir: + cfg = _make_cfg(tmpdir, distance=cls._D, n_rounds=cls._T, basis=basis, n_samples=cls._N) + cfg.test.n_rounds = cls._T + # Create a dummy engine file so ONNX_WORKFLOW=3 finds it + engine_path = str( + Path(tmpdir) / f"predecoder_memory_d{cls._D}_T{cls._T}_{basis}.engine" + ) + with open(engine_path, "wb") as _f: + _f.write(b"dummy_engine") + + with patch("data.factory.DatapipeFactory") as mf, \ + patch.dict("os.environ", {"ONNX_WORKFLOW": "3"}), \ + patch.dict("sys.modules", {"tensorrt": mock_trt}), \ + patch("os.getcwd", return_value=tmpdir), \ + _patch_tensor_to_for_mock_cuda(): + mf.create_datapipe_inference.return_value = real_ds + result = decoder_ablation_study(_ZeroModel(), mock_device, _DummyDist(), cfg) + return result + + def test_trt_path_returns_correct_sample_count(self): + self.assertEqual(self._result_x["total_samples"], self._N) + + def test_trt_path_result_has_all_required_keys(self): + for key in ( + "baseline_errors", "decoder_errors", "residual_weights", "weight_bucket_stats", + "agreement_count", "unavailable_decoders" + ): + self.assertIn(key, self._result_x) + + def test_trt_path_base_decoders_present(self): + self.assertTrue( + set(self._decoder_names).issubset(set(self._result_x["decoder_errors"].keys())) + ) + + def test_trt_path_residual_weights_length_matches_sample_count(self): + self.assertEqual(len(self._result_x["residual_weights"]), self._N) + + def test_trt_path_decoder_error_counts_are_non_negative(self): + for name in self._decoder_names: + with self.subTest(decoder=name): + self.assertGreaterEqual(self._result_x["decoder_errors"][name], 0) + self.assertLessEqual(self._result_x["decoder_errors"][name], self._N) + + def test_trt_path_z_basis_also_works(self): + self.assertEqual(self._result_z["total_samples"], self._N) + + def test_trt_path_agreement_count_within_bounds(self): + self.assertGreaterEqual(self._result_x["agreement_count"], 0) + self.assertLessEqual(self._result_x["agreement_count"], self._N) + + +class TestDecoderAblationStudyExportAndBuildTRT(unittest.TestCase): + """ + ONNX_WORKFLOW=2 (EXPORT_AND_USE_TRT): mock both torch.onnx.export and + tensorrt so the full export → engine-build → TRT-inference path runs + end-to-end without a GPU or a real ONNX model. + """ + + _D = 3 + _T = 3 + _N = 8 + + @classmethod + def setUpClass(cls): + from evaluation.failure_analysis import DECODER_NAMES + cls._decoder_names = DECODER_NAMES + cls._result = cls._run_once("X") + + @classmethod + def _run_once(cls, basis="X"): + from evaluation.failure_analysis import decoder_ablation_study + from data.datapipe_stim import QCDataPipePreDecoder_Memory_inference + real_ds = QCDataPipePreDecoder_Memory_inference( + distance=cls._D, + n_rounds=cls._T, + num_samples=cls._N, + error_mode="circuit_level_surface_custom", + p_error=0.01, + measure_basis=basis, + code_rotation="XV", + ) + circuit = real_ds.circ.stim_circuit + det_model = circuit.detector_error_model( + decompose_errors=True, approximate_disjoint_errors=True + ) + mock_trt = _make_mock_trt_module(det_model.num_detectors) + mock_device = _MockCUDADevice() + + with tempfile.TemporaryDirectory() as tmpdir: + cfg = _make_cfg(tmpdir, distance=cls._D, n_rounds=cls._T, basis=basis, n_samples=cls._N) + cfg.test.n_rounds = cls._T + + def _fake_onnx_export(module, *args, **kwargs): + # Write an empty placeholder so the TRT parser can open the file. + f = kwargs.get("f") or (args[1] if len(args) > 1 else None) + if f: + Path(f).touch() + + with patch("data.factory.DatapipeFactory") as mf, \ + patch.dict("os.environ", {"ONNX_WORKFLOW": "2"}), \ + patch.dict("sys.modules", {"tensorrt": mock_trt}), \ + patch("os.getcwd", return_value=tmpdir), \ + patch("torch.onnx.export", side_effect=_fake_onnx_export), \ + _patch_tensor_to_for_mock_cuda(): + mf.create_datapipe_inference.return_value = real_ds + result = decoder_ablation_study(_ZeroModel(), mock_device, _DummyDist(), cfg) + return result + + def test_export_and_build_returns_correct_sample_count(self): + self.assertEqual(self._result["total_samples"], self._N) + + def test_export_and_build_result_has_all_required_keys(self): + for key in ( + "baseline_errors", "decoder_errors", "residual_weights", "weight_bucket_stats", + "agreement_count", "unavailable_decoders" + ): + self.assertIn(key, self._result) + + def test_export_and_build_all_decoders_present(self): + self.assertTrue( + set(self._decoder_names).issubset(set(self._result["decoder_errors"].keys())) + ) + + def test_export_and_build_residual_weights_length(self): + self.assertEqual(len(self._result["residual_weights"]), self._N) + + if __name__ == "__main__": unittest.main()