|
43 | 43 | import torch |
44 | 44 |
|
45 | 45 | _compile_disabled = os.environ.get("PREDECODER_TORCH_COMPILE", "").strip().lower() in ( |
46 | | - "0", "false", "no", "off", |
| 46 | + "0", |
| 47 | + "false", |
| 48 | + "no", |
| 49 | + "off", |
47 | 50 | ) |
48 | 51 |
|
49 | 52 |
|
@@ -823,7 +826,9 @@ def _get_compiled_seq_wr(num_layers: int): |
823 | 826 | def _wr_fn(error_f, padded_masks, is_boundary, layer_valid): |
824 | 827 | return _wr_seq_step_nobreak(error_f, padded_masks, is_boundary, layer_valid, nl) |
825 | 828 |
|
826 | | - compiled = torch.compile(_wr_fn, mode="reduce-overhead", fullgraph=True, disable=_compile_disabled) |
| 829 | + compiled = torch.compile( |
| 830 | + _wr_fn, mode="reduce-overhead", fullgraph=True, disable=_compile_disabled |
| 831 | + ) |
827 | 832 | _compiled_seq_wr_cache[key] = compiled |
828 | 833 | return compiled |
829 | 834 |
|
@@ -2221,7 +2226,9 @@ def _timelike_loop( |
2221 | 2226 |
|
2222 | 2227 | return x_work, z_work, sz_work, sx_work |
2223 | 2228 |
|
2224 | | - compiled = torch.compile(_timelike_loop, mode="reduce-overhead", fullgraph=True, disable=_compile_disabled) |
| 2229 | + compiled = torch.compile( |
| 2230 | + _timelike_loop, mode="reduce-overhead", fullgraph=True, disable=_compile_disabled |
| 2231 | + ) |
2225 | 2232 | _compiled_timelike_cache[key] = compiled |
2226 | 2233 | return compiled |
2227 | 2234 |
|
@@ -2289,7 +2296,9 @@ def _w2_loop( |
2289 | 2296 |
|
2290 | 2297 | return x_work, z_work, sz_work, sx_work |
2291 | 2298 |
|
2292 | | - compiled = torch.compile(_w2_loop, mode="max-autotune", fullgraph=True, disable=_compile_disabled) |
| 2299 | + compiled = torch.compile( |
| 2300 | + _w2_loop, mode="max-autotune", fullgraph=True, disable=_compile_disabled |
| 2301 | + ) |
2293 | 2302 | _compiled_weight2_cache[key] = compiled |
2294 | 2303 | return compiled |
2295 | 2304 |
|
|
0 commit comments