@@ -211,5 +211,62 @@ def test_exact_threshold_qualifies(self):
211211 )
212212
213213
# ---------------------------------------------------------------------------
# torch.compile + DataLoader spawn-worker guard
# ---------------------------------------------------------------------------
class TestCompileSpawnGuard(unittest.TestCase):
    """Guard that forces num_workers=0 when torch.compile is active.

    torch.compile + DataLoader multiprocessing_context=spawn segfaults (leaked
    semaphores, core dump). Both LER inference and SDR functions apply this
    guard before constructing the DataLoader.

    NOTE(review): the ``_apply_*`` helpers re-implement the guard logic
    locally instead of calling the production functions, so they must be
    kept in sync with those functions by hand.
    """

    def _apply_ler_guard(self, applied_compile: bool, num_workers: int) -> dict:
        """Mirror the guard in run_inference_and_decode_pre_decoder_memory.

        Args:
            applied_compile: Whether torch.compile was applied to the model.
            num_workers: Requested number of DataLoader workers.

        Returns:
            The DataLoader kwargs after the guard: ``num_workers`` is forced
            to 0 when compiled, and ``prefetch_factor`` is removed whenever
            the resulting ``num_workers`` is 0.
        """
        kwargs = {"num_workers": num_workers, "prefetch_factor": 2}
        if applied_compile and int(kwargs.get("num_workers", 0)) > 0:
            kwargs["num_workers"] = 0
        # prefetch_factor is dropped for the single-process case; presumably
        # because DataLoader rejects it without workers -- confirm in docs.
        if kwargs.get("num_workers", 0) == 0:
            kwargs.pop("prefetch_factor", None)
        return kwargs

    def _apply_sdr_guard(self, compile_env: str, num_workers: int) -> dict:
        """Mirror the guard in compute_syndrome_density_reduction.

        Args:
            compile_env: Value temporarily assigned to the
                ``PREDECODER_TORCH_COMPILE`` environment variable.
            num_workers: Requested number of DataLoader workers.

        Returns:
            The DataLoader kwargs after the guard has been applied.
        """
        kwargs = {"num_workers": num_workers}
        # patch.dict restores os.environ on exit, so tests do not leak state.
        with patch.dict(os.environ, {"PREDECODER_TORCH_COMPILE": compile_env}):
            if _get_env_bool("PREDECODER_TORCH_COMPILE", True) and int(kwargs.get("num_workers", 0)) > 0:
                kwargs["num_workers"] = 0
        return kwargs

    # --- LER inference guard ---

    def test_ler_guard_forces_zero_workers_when_compiled(self):
        kwargs = self._apply_ler_guard(applied_compile=True, num_workers=16)
        self.assertEqual(kwargs["num_workers"], 0)

    def test_ler_guard_keeps_workers_when_not_compiled(self):
        kwargs = self._apply_ler_guard(applied_compile=False, num_workers=16)
        self.assertEqual(kwargs["num_workers"], 16)

    def test_ler_guard_noop_when_workers_already_zero(self):
        kwargs = self._apply_ler_guard(applied_compile=True, num_workers=0)
        self.assertEqual(kwargs["num_workers"], 0)

    def test_ler_guard_drops_prefetch_factor_when_workers_zeroed(self):
        kwargs = self._apply_ler_guard(applied_compile=True, num_workers=16)
        self.assertNotIn("prefetch_factor", kwargs)

    def test_ler_guard_keeps_prefetch_factor_when_not_compiled(self):
        # Without compile the worker count is untouched, so prefetching must
        # survive; otherwise a guard that always drops it would go unnoticed.
        kwargs = self._apply_ler_guard(applied_compile=False, num_workers=16)
        self.assertEqual(kwargs["prefetch_factor"], 2)

    def test_ler_guard_drops_prefetch_factor_when_workers_already_zero(self):
        # The prefetch_factor cleanup keys off the final worker count, not
        # off whether compile was applied.
        kwargs = self._apply_ler_guard(applied_compile=False, num_workers=0)
        self.assertEqual(kwargs["num_workers"], 0)
        self.assertNotIn("prefetch_factor", kwargs)

    # --- SDR guard (env-var based) ---

    def test_sdr_guard_forces_zero_workers_when_compile_env_on(self):
        kwargs = self._apply_sdr_guard(compile_env="1", num_workers=16)
        self.assertEqual(kwargs["num_workers"], 0)

    def test_sdr_guard_keeps_workers_when_compile_env_off(self):
        kwargs = self._apply_sdr_guard(compile_env="0", num_workers=16)
        self.assertEqual(kwargs["num_workers"], 16)

    def test_sdr_guard_noop_when_workers_already_zero(self):
        # Zero workers must stay zero regardless of the compile setting.
        for compile_env in ("0", "1"):
            with self.subTest(compile_env=compile_env):
                kwargs = self._apply_sdr_guard(compile_env=compile_env, num_workers=0)
                self.assertEqual(kwargs["num_workers"], 0)
270+
# Run every TestCase in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
0 commit comments