diff --git a/swift/template/vision_utils.py b/swift/template/vision_utils.py
index 9b828f6765..362547eac0 100644
--- a/swift/template/vision_utils.py
+++ b/swift/template/vision_utils.py
@@ -296,7 +296,69 @@ def uniform_sample(_l, _n):
     return frames
 
 
-def load_audio(audio: Union[str, bytes], sampling_rate: int, return_sr: bool = False):
+def _decode_worker(conn, func, args, kwargs):
+    try:
+        conn.send((True, func(*args, **kwargs)))
+    except Exception as e:
+        conn.send((False, e))
+    finally:
+        conn.close()
+
+
+def _join_or_kill(process, grace: float = 5.0):
+    # terminate() only sends SIGTERM. A forked worker inherits the parent's signal handlers, and a
+    # Python-level SIGTERM handler cannot run while the worker is stuck in C, so a plain join()
+    # could block forever on the very hang this guards against. Escalate to SIGKILL after `grace`.
+    process.join(grace)
+    if process.is_alive():
+        process.kill()
+        process.join()
+
+
+def _decode_with_timeout(func: Callable[..., _T], *args, **kwargs) -> _T:
+    # Native media decoders (audioread/ffmpeg, decord) can deadlock in C while holding the GIL on a
+    # corrupt/unsupported clip, silently hanging a DataLoader worker forever; a signal-based timeout
+    # can't interrupt them. When `MEDIA_DECODE_TIMEOUT` (seconds) > 0, decode in a killable subprocess.
+    timeout = get_env_args('media_decode_timeout', float, 0)
+    if not timeout or timeout <= 0:
+        return func(*args, **kwargs)
+    import multiprocessing as mp
+
+    # Fork the decode worker: load_audio runs inside the data pipeline where fork is already the
+    # norm (PyTorch DataLoader), and unlike forkserver/spawn it does not re-import the training
+    # entrypoint per call. Fall back to the default context where fork is unavailable.
+    try:
+        ctx = mp.get_context('fork')
+    except ValueError:
+        ctx = mp.get_context()
+    # Use a Pipe + poll() rather than a (Simple)Queue: a Queue is backed by an OS pipe, so the
+    # worker's send blocks once the decoded payload exceeds the pipe buffer (~64KB) until the parent
+    # reads -- but the parent would be in join(), so every real audio/video clip would deadlock and
+    # false-timeout. poll(timeout) waits for the decode to finish (the worker sends only after
+    # decoding); recv() then drains the payload, unblocking the worker as it writes.
+    recv_conn, send_conn = ctx.Pipe(duplex=False)
+    process = ctx.Process(target=_decode_worker, args=(send_conn, func, args, kwargs))
+    process.start()
+    send_conn.close()  # parent only reads; closing its copy lets recv() see EOF if the worker dies
+    try:
+        if not recv_conn.poll(timeout):
+            process.terminate()
+            raise TimeoutError(f'Media decode exceeded MEDIA_DECODE_TIMEOUT={timeout}s and was killed '
+                               '(likely a corrupt or unsupported clip).')
+        try:
+            ok, payload = recv_conn.recv()
+        except EOFError:
+            _join_or_kill(process)
+            raise RuntimeError(f'Media decode subprocess exited abnormally (exitcode={process.exitcode}).')
+    finally:
+        recv_conn.close()
+        _join_or_kill(process)  # bounded reap: SIGKILL a worker that survived SIGTERM or won't exit
+    if not ok:
+        raise payload
+    return payload
+
+
+def _load_audio(audio: Union[str, bytes], sampling_rate: int):
     import librosa
     try:
         audio_io = load_file(audio)
@@ -308,6 +370,11 @@ def load_audio(audio: Union[str, bytes], sampling_rate: int, return_sr: bool = F
         else:
             audio_io = _check_path(audio) or audio
     res = librosa.load(audio_io, sr=sampling_rate)
+    return res
+
+
+def load_audio(audio: Union[str, bytes], sampling_rate: int, return_sr: bool = False):
+    res = _decode_with_timeout(_load_audio, audio, sampling_rate)
     return res if return_sr else res[0]
 
 
diff --git a/tests/utils/test_vision_utils.py b/tests/utils/test_vision_utils.py
new file mode 100644
index 0000000000..c48f883ce4
--- /dev/null
+++ b/tests/utils/test_vision_utils.py
@@ -0,0 +1,83 @@
+import os
+import signal
+import time
+import unittest
+
+from swift.template.vision_utils import _decode_with_timeout
+
+
+def _sleep_forever(*args, **kwargs):
+    time.sleep(3600)
+
+
+def _sleep_ignoring_sigterm(*args, **kwargs):
+    # Models a worker wedged in native code: SIGTERM never takes effect, only SIGKILL does.
+    signal.signal(signal.SIGTERM, signal.SIG_IGN)
+    time.sleep(3600)
+
+
+def _echo(value):
+    return value
+
+
+def _raise_value_error(*args, **kwargs):
+    raise ValueError('boom')
+
+
+def _big_payload(*args, **kwargs):
+    # ~1 MB, far above the ~64 KB OS pipe buffer that made a SimpleQueue worker block its put()
+    # while the parent waited in join() -- a deadlock that false-timed-out every real media clip.
+    return b'x' * (1024 * 1024)
+
+
+class TestDecodeWithTimeout(unittest.TestCase):
+
+    def setUp(self):
+        self._saved_timeout = os.environ.get('MEDIA_DECODE_TIMEOUT')
+
+    def tearDown(self):
+        if self._saved_timeout is None:
+            os.environ.pop('MEDIA_DECODE_TIMEOUT', None)
+        else:
+            os.environ['MEDIA_DECODE_TIMEOUT'] = self._saved_timeout
+
+    def test_kills_hung_decode(self):
+        # A decode that never returns must be killed and surface a TimeoutError rather than hang.
+        os.environ['MEDIA_DECODE_TIMEOUT'] = '2'
+        start = time.monotonic()
+        with self.assertRaises(TimeoutError):
+            _decode_with_timeout(_sleep_forever)
+        self.assertLess(time.monotonic() - start, 30)
+
+    def test_kills_decode_that_ignores_sigterm(self):
+        # terminate() alone cannot stop such a worker; the parent must escalate to SIGKILL
+        # instead of blocking in join() forever.
+        os.environ['MEDIA_DECODE_TIMEOUT'] = '2'
+        start = time.monotonic()
+        with self.assertRaises(TimeoutError):
+            _decode_with_timeout(_sleep_ignoring_sigterm)
+        self.assertLess(time.monotonic() - start, 30)
+
+    def test_returns_result_when_fast(self):
+        os.environ['MEDIA_DECODE_TIMEOUT'] = '10'
+        self.assertEqual(_decode_with_timeout(_echo, 'ok'), 'ok')
+
+    def test_propagates_decode_error(self):
+        os.environ['MEDIA_DECODE_TIMEOUT'] = '10'
+        with self.assertRaises(ValueError):
+            _decode_with_timeout(_raise_value_error)
+
+    def test_disabled_calls_directly(self):
+        # Default (unset / 0): no subprocess, original behavior and zero overhead.
+        os.environ.pop('MEDIA_DECODE_TIMEOUT', None)
+        self.assertEqual(_decode_with_timeout(_echo, 'direct'), 'direct')
+
+    def test_handles_large_payload(self):
+        # A decoded payload larger than the OS pipe buffer must transfer without deadlocking the
+        # worker (regression: the previous SimpleQueue implementation false-timed-out here).
+        os.environ['MEDIA_DECODE_TIMEOUT'] = '10'
+        self.assertEqual(_decode_with_timeout(_big_payload), b'x' * (1024 * 1024))
+
+
+if __name__ == '__main__':
+    unittest.main()