From 4dd200b4240b27031fb7fa749014f13f39b79505 Mon Sep 17 00:00:00 2001 From: HaozheZhang6 Date: Thu, 11 Jun 2026 16:56:22 +0000 Subject: [PATCH 1/3] [bugfix] add optional MEDIA_DECODE_TIMEOUT to prevent silent media-decode hang Corrupt or unsupported audio/video clips can make the native decoders (librosa->audioread/ffmpeg, decord) deadlock in C while holding the GIL, which silently hangs a DataLoader worker forever with GPUs idle and no error logged. Add an opt-in hard wall-clock timeout: when MEDIA_DECODE_TIMEOUT (seconds) > 0, the decode runs in a forked worker that is killed on overrun and raises TimeoutError so one bad clip cannot freeze the whole run. Default (unset/0) keeps the original in-process path with zero overhead. Applied to load_audio (the reproduced hang path); the helper is reusable for the decord video paths. Fixes #9507 --- swift/template/vision_utils.py | 47 +++++++++++++++++++++++++++++++- tests/utils/test_vision_utils.py | 42 ++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 1 deletion(-) create mode 100644 tests/utils/test_vision_utils.py diff --git a/swift/template/vision_utils.py b/swift/template/vision_utils.py index 9b828f6765..661fd2c597 100644 --- a/swift/template/vision_utils.py +++ b/swift/template/vision_utils.py @@ -296,7 +296,47 @@ def uniform_sample(_l, _n): return frames -def load_audio(audio: Union[str, bytes], sampling_rate: int, return_sr: bool = False): +def _decode_worker(queue, func, args, kwargs): + try: + queue.put((True, func(*args, **kwargs))) + except Exception as e: + queue.put((False, e)) + + +def _decode_with_timeout(func: Callable[..., _T], *args, **kwargs) -> _T: + # Native media decoders (audioread/ffmpeg, decord) can deadlock in C while holding the GIL on a + # corrupt/unsupported clip, silently hanging a DataLoader worker forever; a signal-based timeout + # can't interrupt them. When `MEDIA_DECODE_TIMEOUT` (seconds) > 0, decode in a killable subprocess. 
+ timeout = get_env_args('media_decode_timeout', float, 0) + if not timeout or timeout <= 0: + return func(*args, **kwargs) + import multiprocessing as mp + + # Fork the decode worker: load_audio runs inside the data pipeline where fork is already the + # norm (PyTorch DataLoader), and unlike forkserver/spawn it does not re-import the training + # entrypoint per call. Fall back to the default context where fork is unavailable. + try: + ctx = mp.get_context('fork') + except ValueError: + ctx = mp.get_context() + queue = ctx.SimpleQueue() + process = ctx.Process(target=_decode_worker, args=(queue, func, args, kwargs)) + process.start() + process.join(timeout) + if process.is_alive(): + process.terminate() + process.join() + raise TimeoutError(f'Media decode exceeded MEDIA_DECODE_TIMEOUT={timeout}s and was killed ' + '(likely a corrupt or unsupported clip).') + if process.exitcode != 0: + raise RuntimeError(f'Media decode subprocess exited abnormally (exitcode={process.exitcode}).') + ok, payload = queue.get() + if not ok: + raise payload + return payload + + +def _load_audio(audio: Union[str, bytes], sampling_rate: int): import librosa try: audio_io = load_file(audio) @@ -308,6 +348,11 @@ def load_audio(audio: Union[str, bytes], sampling_rate: int, return_sr: bool = F else: audio_io = _check_path(audio) or audio res = librosa.load(audio_io, sr=sampling_rate) + return res + + +def load_audio(audio: Union[str, bytes], sampling_rate: int, return_sr: bool = False): + res = _decode_with_timeout(_load_audio, audio, sampling_rate) return res if return_sr else res[0] diff --git a/tests/utils/test_vision_utils.py b/tests/utils/test_vision_utils.py new file mode 100644 index 0000000000..af4ccc8991 --- /dev/null +++ b/tests/utils/test_vision_utils.py @@ -0,0 +1,42 @@ +import pytest +import time + +from swift.template.vision_utils import _decode_with_timeout + + +def _sleep_forever(*args, **kwargs): + time.sleep(3600) + + +def _echo(value): + return value + + +def 
_raise_value_error(*args, **kwargs): + raise ValueError('boom') + + +def test_decode_with_timeout_kills_hung_decode(monkeypatch): + # A decode that never returns must be killed and surface a TimeoutError rather than hang. + monkeypatch.setenv('MEDIA_DECODE_TIMEOUT', '2') + start = time.time() + with pytest.raises(TimeoutError): + _decode_with_timeout(_sleep_forever) + assert time.time() - start < 30 + + +def test_decode_with_timeout_returns_result_when_fast(monkeypatch): + monkeypatch.setenv('MEDIA_DECODE_TIMEOUT', '10') + assert _decode_with_timeout(_echo, 'ok') == 'ok' + + +def test_decode_with_timeout_propagates_decode_error(monkeypatch): + monkeypatch.setenv('MEDIA_DECODE_TIMEOUT', '10') + with pytest.raises(ValueError, match='boom'): + _decode_with_timeout(_raise_value_error) + + +def test_decode_with_timeout_disabled_calls_directly(monkeypatch): + # Default (unset / 0): no subprocess, original behavior and zero overhead. + monkeypatch.delenv('MEDIA_DECODE_TIMEOUT', raising=False) + assert _decode_with_timeout(_echo, 'direct') == 'direct' From acf7068acb60d240efab435f6cd63777d1e908d9 Mon Sep 17 00:00:00 2001 From: Haozhe Zhang Date: Mon, 15 Jun 2026 17:54:32 -0700 Subject: [PATCH 2/3] [bugfix] use Pipe+poll for media decode timeout to avoid large-payload deadlock The previous SimpleQueue-based implementation is backed by an OS pipe, so the worker's put() blocks once the decoded payload exceeds the pipe buffer (~64KB) until the parent reads -- but the parent waited in join(), deadlocking and false-timing-out every real audio/video clip. Switch to a Pipe + poll(timeout): poll() waits for the decode to finish (the worker sends only after decoding), then recv() drains the payload, unblocking the worker as it writes. Add a 1MB-payload regression test. 
--- swift/template/vision_utils.py | 41 +++++++++++++++++++++----------- tests/utils/test_vision_utils.py | 13 ++++++++++ 2 files changed, 40 insertions(+), 14 deletions(-) diff --git a/swift/template/vision_utils.py b/swift/template/vision_utils.py index 661fd2c597..362547eac0 100644 --- a/swift/template/vision_utils.py +++ b/swift/template/vision_utils.py @@ -296,11 +296,13 @@ def uniform_sample(_l, _n): return frames -def _decode_worker(queue, func, args, kwargs): +def _decode_worker(conn, func, args, kwargs): try: - queue.put((True, func(*args, **kwargs))) + conn.send((True, func(*args, **kwargs))) except Exception as e: - queue.put((False, e)) + conn.send((False, e)) + finally: + conn.close() def _decode_with_timeout(func: Callable[..., _T], *args, **kwargs) -> _T: @@ -319,18 +321,29 @@ def _decode_with_timeout(func: Callable[..., _T], *args, **kwargs) -> _T: ctx = mp.get_context('fork') except ValueError: ctx = mp.get_context() - queue = ctx.SimpleQueue() - process = ctx.Process(target=_decode_worker, args=(queue, func, args, kwargs)) + # Use a Pipe + poll() rather than a (Simple)Queue: a Queue is backed by an OS pipe, so the + # worker's send blocks once the decoded payload exceeds the pipe buffer (~64KB) until the parent + # reads -- but the parent would be in join(), so every real audio/video clip would deadlock and + # false-timeout. poll(timeout) waits for the decode to finish (the worker sends only after + # decoding); recv() then drains the payload, unblocking the worker as it writes. 
+ recv_conn, send_conn = ctx.Pipe(duplex=False) + process = ctx.Process(target=_decode_worker, args=(send_conn, func, args, kwargs)) process.start() - process.join(timeout) - if process.is_alive(): - process.terminate() - process.join() - raise TimeoutError(f'Media decode exceeded MEDIA_DECODE_TIMEOUT={timeout}s and was killed ' - '(likely a corrupt or unsupported clip).') - if process.exitcode != 0: - raise RuntimeError(f'Media decode subprocess exited abnormally (exitcode={process.exitcode}).') - ok, payload = queue.get() + send_conn.close() # parent only reads; closing its copy lets recv() see EOF if the worker dies + try: + if not recv_conn.poll(timeout): + process.terminate() + process.join() + raise TimeoutError(f'Media decode exceeded MEDIA_DECODE_TIMEOUT={timeout}s and was killed ' + '(likely a corrupt or unsupported clip).') + try: + ok, payload = recv_conn.recv() + except EOFError: + process.join() + raise RuntimeError(f'Media decode subprocess exited abnormally (exitcode={process.exitcode}).') + finally: + recv_conn.close() + process.join() if not ok: raise payload return payload diff --git a/tests/utils/test_vision_utils.py b/tests/utils/test_vision_utils.py index af4ccc8991..136e5497dc 100644 --- a/tests/utils/test_vision_utils.py +++ b/tests/utils/test_vision_utils.py @@ -16,6 +16,12 @@ def _raise_value_error(*args, **kwargs): raise ValueError('boom') +def _big_payload(*args, **kwargs): + # ~1 MB, far above the ~64 KB OS pipe buffer that made a SimpleQueue worker block its put() + # while the parent waited in join() -- a deadlock that false-timed-out every real media clip. + return b'x' * (1024 * 1024) + + def test_decode_with_timeout_kills_hung_decode(monkeypatch): # A decode that never returns must be killed and surface a TimeoutError rather than hang. 
monkeypatch.setenv('MEDIA_DECODE_TIMEOUT', '2') @@ -40,3 +46,10 @@ def test_decode_with_timeout_disabled_calls_directly(monkeypatch): # Default (unset / 0): no subprocess, original behavior and zero overhead. monkeypatch.delenv('MEDIA_DECODE_TIMEOUT', raising=False) assert _decode_with_timeout(_echo, 'direct') == 'direct' + + +def test_decode_with_timeout_handles_large_payload(monkeypatch): + # A decoded payload larger than the OS pipe buffer must transfer without deadlocking the + # worker (regression: the previous SimpleQueue implementation false-timed-out here). + monkeypatch.setenv('MEDIA_DECODE_TIMEOUT', '10') + assert _decode_with_timeout(_big_payload) == b'x' * (1024 * 1024) From 00e90927aee7259d265bb3e640c245dbccc1b60c Mon Sep 17 00:00:00 2001 From: Haozhe Zhang Date: Mon, 15 Jun 2026 22:39:52 -0700 Subject: [PATCH 3/3] test: rewrite media-decode-timeout test as unittest.TestCase ms-swift's CI runs tests via unittest (tests/run.py), not pytest, so the pytest-based module failed to import (ModuleNotFoundError: pytest) and its module-level test functions were never discovered. Port to unittest.TestCase with setUp/tearDown env handling and self.assertRaises; no change to the code under test (the decode-error test now checks only the exception type, no longer the 'boom' message). --- tests/utils/test_vision_utils.py | 60 +++++++++++++++++++------------- 1 file changed, 36 insertions(+), 24 deletions(-) diff --git a/tests/utils/test_vision_utils.py b/tests/utils/test_vision_utils.py index 136e5497dc..c48f883ce4 100644 --- a/tests/utils/test_vision_utils.py +++ b/tests/utils/test_vision_utils.py @@ -1,5 +1,6 @@ -import pytest +import os import time +import unittest from swift.template.vision_utils import _decode_with_timeout @@ -22,34 +23,45 @@ def _big_payload(*args, **kwargs): return b'x' * (1024 * 1024) -def test_decode_with_timeout_kills_hung_decode(monkeypatch): - # A decode that never returns must be killed and surface a TimeoutError rather than hang. 
- monkeypatch.setenv('MEDIA_DECODE_TIMEOUT', '2') - start = time.time() - with pytest.raises(TimeoutError): - _decode_with_timeout(_sleep_forever) - assert time.time() - start < 30 +class TestDecodeWithTimeout(unittest.TestCase): + def setUp(self): + self._saved_timeout = os.environ.get('MEDIA_DECODE_TIMEOUT') -def test_decode_with_timeout_returns_result_when_fast(monkeypatch): - monkeypatch.setenv('MEDIA_DECODE_TIMEOUT', '10') - assert _decode_with_timeout(_echo, 'ok') == 'ok' + def tearDown(self): + if self._saved_timeout is None: + os.environ.pop('MEDIA_DECODE_TIMEOUT', None) + else: + os.environ['MEDIA_DECODE_TIMEOUT'] = self._saved_timeout + def test_kills_hung_decode(self): + # A decode that never returns must be killed and surface a TimeoutError rather than hang. + os.environ['MEDIA_DECODE_TIMEOUT'] = '2' + start = time.time() + with self.assertRaises(TimeoutError): + _decode_with_timeout(_sleep_forever) + self.assertLess(time.time() - start, 30) -def test_decode_with_timeout_propagates_decode_error(monkeypatch): - monkeypatch.setenv('MEDIA_DECODE_TIMEOUT', '10') - with pytest.raises(ValueError, match='boom'): - _decode_with_timeout(_raise_value_error) + def test_returns_result_when_fast(self): + os.environ['MEDIA_DECODE_TIMEOUT'] = '10' + self.assertEqual(_decode_with_timeout(_echo, 'ok'), 'ok') + def test_propagates_decode_error(self): + os.environ['MEDIA_DECODE_TIMEOUT'] = '10' + with self.assertRaises(ValueError): + _decode_with_timeout(_raise_value_error) -def test_decode_with_timeout_disabled_calls_directly(monkeypatch): - # Default (unset / 0): no subprocess, original behavior and zero overhead. - monkeypatch.delenv('MEDIA_DECODE_TIMEOUT', raising=False) - assert _decode_with_timeout(_echo, 'direct') == 'direct' + def test_disabled_calls_directly(self): + # Default (unset / 0): no subprocess, original behavior and zero overhead. 
+ os.environ.pop('MEDIA_DECODE_TIMEOUT', None) + self.assertEqual(_decode_with_timeout(_echo, 'direct'), 'direct') + def test_handles_large_payload(self): + # A decoded payload larger than the OS pipe buffer must transfer without deadlocking the + # worker (regression: the previous SimpleQueue implementation false-timed-out here). + os.environ['MEDIA_DECODE_TIMEOUT'] = '10' + self.assertEqual(_decode_with_timeout(_big_payload), b'x' * (1024 * 1024)) -def test_decode_with_timeout_handles_large_payload(monkeypatch): - # A decoded payload larger than the OS pipe buffer must transfer without deadlocking the - # worker (regression: the previous SimpleQueue implementation false-timed-out here). - monkeypatch.setenv('MEDIA_DECODE_TIMEOUT', '10') - assert _decode_with_timeout(_big_payload) == b'x' * (1024 * 1024) + +if __name__ == '__main__': + unittest.main()