
Commit 2db609e

ci: exercise multi-worker DataLoader (num_workers > 0) in CI
Production defaults use num_workers=4 with spawn multiprocessing, but all CI jobs and tests forced num_workers=0. Add coverage for both layers:

- New test (test_dataloader_multiprocessing.py): verifies the Stim inference datapipe is pickle-safe and produces correct results with num_workers=2 across X, Z, and mixed bases. Runs on CPU in a dedicated ci.yml job.
- New ci-gpu.yml step: re-runs inference with PREDECODER_INFERENCE_NUM_WORKERS=2 after the existing smoke run, exercising the full logical_error_rate.py pipeline (multi-worker DataLoader → model forward → PyMatching → LER check).

Signed-off-by: kvmto <kmato@nvidia.com>
1 parent 993e797 commit 2db609e
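
The pickle-safety requirement in the commit message follows from the spawn start method: each worker process receives a pickled copy of the dataset, so any attribute that cannot be pickled makes worker startup fail. The following is a minimal stdlib-only sketch of that check, not code from this repository; `is_pickle_safe`, `good_state`, and `bad_state` are hypothetical names chosen for illustration.

```python
import functools
import operator
import pickle


def is_pickle_safe(obj) -> bool:
    """Return True if obj survives a pickle round trip (hypothetical helper)."""
    try:
        pickle.loads(pickle.dumps(obj))
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True


# Plain data plus a partial over a builtin function can be pickled.
good_state = {
    "distance": 3,
    "n_rounds": 3,
    "transform": functools.partial(operator.add, 1),
}

# A lambda cannot be pickled by the standard pickle module, so a dataset
# holding one would fail when handed to a spawned worker.
bad_state = {
    "distance": 3,
    "n_rounds": 3,
    "transform": lambda x: x + 1,
}

print(is_pickle_safe(good_state), is_pickle_safe(bad_state))
```

Replacing a lambda with a module-level function or a `functools.partial` is one common way to keep dataset state picklable.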

3 files changed

Lines changed: 121 additions & 0 deletions

.github/workflows/ci-gpu.yml

Lines changed: 15 additions & 0 deletions
@@ -96,6 +96,21 @@ jobs:
           PREDECODER_TEST_SAMPLES: "2048"
           PREDECODER_TRAIN_EPOCHS: "2"

+      - name: Training + inference with multi-worker DataLoader (num_workers=2)
+        shell: bash
+        run: |
+          source .venv_train_${{ matrix.python-version }}/bin/activate
+          bash code/scripts/smoke_run.sh 2>&1 | tee /tmp/ci_multiworker.log
+          r=${PIPESTATUS[0]}; [ $r -ne 0 ] && exit $r
+          python code/scripts/check_ler_from_log.py /tmp/ci_multiworker.log --max-ler 0.35
+        env:
+          EXPERIMENT_NAME: ci_multiworker
+          PREDECODER_TRAIN_SAMPLES: "16384"
+          PREDECODER_VAL_SAMPLES: "2048"
+          PREDECODER_TEST_SAMPLES: "2048"
+          PREDECODER_TRAIN_EPOCHS: "2"
+          PREDECODER_INFERENCE_NUM_WORKERS: "2"
+
   # ---------------------------------------------------------------------------
   # Mid-tier (~5-10 min): extended training + inference with LER check.
   # Runs only after merge to main (not on PR branches) to save GPU time.
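
The step above passes the worker count as the string environment variable `PREDECODER_INFERENCE_NUM_WORKERS`. This diff does not show how the inference script consumes it, so the sketch below is only an assumption about one reasonable way to parse such a value. `resolve_num_workers` is a hypothetical helper, and the fallback of 4 is taken from the production default named in the commit message.

```python
import os


def resolve_num_workers(env=None, default=4):
    """Parse the worker count from the environment (hypothetical helper).

    Falls back to `default` when the variable is unset or empty, and
    rejects negative values, which DataLoader would not accept either.
    """
    env = os.environ if env is None else env
    raw = env.get("PREDECODER_INFERENCE_NUM_WORKERS")
    if raw is None or raw.strip() == "":
        return default
    value = int(raw)  # raises ValueError on non-numeric input
    if value < 0:
        raise ValueError(f"num_workers must be >= 0, got {value}")
    return value


# The CI step sets the variable to "2"; workflow env values are strings.
print(resolve_num_workers({"PREDECODER_INFERENCE_NUM_WORKERS": "2"}))
```

Accepting the mapping as a parameter keeps the helper testable without mutating the real process environment.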

.github/workflows/ci.yml

Lines changed: 24 additions & 0 deletions
@@ -71,6 +71,30 @@ jobs:
           SKIP_TESTS: "0"
           PIP_EXTRA_INDEX_URL: "https://download.pytorch.org/whl/cpu"

+  # ---------------------------------------------------------------------------
+  # Multi-worker DataLoader: verifies the Stim inference datapipe works with
+  # num_workers > 0 (spawn multiprocessing context), matching the production
+  # default of num_workers=4 but never exercised in other CI jobs.
+  # ---------------------------------------------------------------------------
+  multiprocessing-dataloader:
+    runs-on: linux-amd64-cpu4
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          lfs: true
+      - uses: actions/setup-python@v5
+        with:
+          python-version: "3.12"
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip setuptools wheel
+          pip install -r code/requirements_public_inference.txt \
+            --extra-index-url https://download.pytorch.org/whl/cpu
+      - name: Run multi-worker DataLoader tests
+        run: >
+          PYTHONPATH=code python -m unittest discover
+          -s code/tests -p "test_dataloader_multiprocessing.py" -v
+
   unit-tests-coverage:
     runs-on: linux-amd64-cpu4
     steps:
code/tests/test_dataloader_multiprocessing.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+"""
+Multi-worker DataLoader tests for the Stim inference datapipe.
+
+Verifies QCDataPipePreDecoder_Memory_inference is pickle-safe and correct
+under num_workers > 0 with spawn multiprocessing (CPU-only, no GPU needed).
+"""
+
+import sys
+import unittest
+from pathlib import Path
+
+_repo_code = Path(__file__).resolve().parent.parent
+if str(_repo_code) not in sys.path:
+    sys.path.insert(0, str(_repo_code))
+
+import torch
+from torch.utils.data import DataLoader
+
+from data.datapipe_stim import QCDataPipePreDecoder_Memory_inference
+
+_D, _T, _N, _BS, _W = 3, 3, 32, 8, 2
+
+
+def _make_loader(basis, num_workers=_W, **kw):
+    ds = QCDataPipePreDecoder_Memory_inference(
+        distance=_D,
+        n_rounds=_T,
+        num_samples=_N,
+        error_mode="circuit_level_surface_custom",
+        p_error=0.01,
+        measure_basis=basis,
+        code_rotation="XV",
+    )
+    opts = dict(batch_size=_BS, shuffle=False)
+    if num_workers > 0:
+        opts["multiprocessing_context"] = "spawn"
+    opts.update(kw)
+    return ds, DataLoader(ds, num_workers=num_workers, **opts)
+
+
+class TestMultiWorkerDataLoader(unittest.TestCase):
+
+    def test_iteration_completes_all_bases(self):
+        for basis in ("X", "Z", "both"):
+            with self.subTest(basis=basis):
+                _, loader = _make_loader(basis)
+                total = sum(b["trainX"].shape[0] for b in loader)
+                self.assertEqual(total, _N)
+
+    def test_matches_single_worker_all_bases(self):
+        for basis in ("X", "Z", "both"):
+            with self.subTest(basis=basis):
+                ds, _ = _make_loader(basis, num_workers=0)
+                loader_0 = DataLoader(ds, batch_size=_BS, shuffle=False)
+                loader_n = DataLoader(
+                    ds,
+                    batch_size=_BS,
+                    shuffle=False,
+                    num_workers=_W,
+                    multiprocessing_context="spawn",
+                )
+                for b0, bn in zip(loader_0, loader_n):
+                    for k in ("trainX", "x_syn_diff", "z_syn_diff", "dets_and_obs"):
+                        torch.testing.assert_close(b0[k], bn[k])
+
+    def test_persistent_workers_with_prefetch(self):
+        _, loader = _make_loader("X", persistent_workers=True, prefetch_factor=2)
+        total = sum(b["trainX"].shape[0] for b in loader)
+        self.assertEqual(total, _N)
+
+
+if __name__ == "__main__":
+    unittest.main()
