Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 59 additions & 1 deletion swift/template/vision_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,60 @@ def uniform_sample(_l, _n):
return frames


def load_audio(audio: Union[str, bytes], sampling_rate: int, return_sr: bool = False):
def _decode_worker(conn, func, args, kwargs):
try:
conn.send((True, func(*args, **kwargs)))
except Exception as e:
conn.send((False, e))
finally:
conn.close()


def _decode_with_timeout(func: Callable[..., _T], *args, **kwargs) -> _T:
# Native media decoders (audioread/ffmpeg, decord) can deadlock in C while holding the GIL on a
# corrupt/unsupported clip, silently hanging a DataLoader worker forever; a signal-based timeout
# can't interrupt them. When `MEDIA_DECODE_TIMEOUT` (seconds) > 0, decode in a killable subprocess.
timeout = get_env_args('media_decode_timeout', float, 0)
if not timeout or timeout <= 0:
return func(*args, **kwargs)
import multiprocessing as mp

# Fork the decode worker: load_audio runs inside the data pipeline where fork is already the
# norm (PyTorch DataLoader), and unlike forkserver/spawn it does not re-import the training
# entrypoint per call. Fall back to the default context where fork is unavailable.
try:
ctx = mp.get_context('fork')
except ValueError:
ctx = mp.get_context()
# Use a Pipe + poll() rather than a (Simple)Queue: a Queue is backed by an OS pipe, so the
# worker's send blocks once the decoded payload exceeds the pipe buffer (~64KB) until the parent
# reads -- but the parent would be in join(), so every real audio/video clip would deadlock and
# false-timeout. poll(timeout) waits for the decode to finish (the worker sends only after
# decoding); recv() then drains the payload, unblocking the worker as it writes.
recv_conn, send_conn = ctx.Pipe(duplex=False)
process = ctx.Process(target=_decode_worker, args=(send_conn, func, args, kwargs))
process.start()
send_conn.close() # parent only reads; closing its copy lets recv() see EOF if the worker dies
try:
if not recv_conn.poll(timeout):
process.terminate()
process.join()
raise TimeoutError(f'Media decode exceeded MEDIA_DECODE_TIMEOUT={timeout}s and was killed '
'(likely a corrupt or unsupported clip).')
try:
ok, payload = recv_conn.recv()
except EOFError:
process.join()
raise RuntimeError(f'Media decode subprocess exited abnormally (exitcode={process.exitcode}).')
finally:
recv_conn.close()
process.join()
if not ok:
raise payload
return payload


def _load_audio(audio: Union[str, bytes], sampling_rate: int):
import librosa
try:
audio_io = load_file(audio)
Expand All @@ -308,6 +361,11 @@ def load_audio(audio: Union[str, bytes], sampling_rate: int, return_sr: bool = F
else:
audio_io = _check_path(audio) or audio
res = librosa.load(audio_io, sr=sampling_rate)
return res


def load_audio(audio: Union[str, bytes], sampling_rate: int, return_sr: bool = False):
res = _decode_with_timeout(_load_audio, audio, sampling_rate)
return res if return_sr else res[0]


Expand Down
67 changes: 67 additions & 0 deletions tests/utils/test_vision_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os
import time
import unittest

from swift.template.vision_utils import _decode_with_timeout


def _sleep_forever(*args, **kwargs):
time.sleep(3600)


def _echo(value):
return value


def _raise_value_error(*args, **kwargs):
raise ValueError('boom')


def _big_payload(*args, **kwargs):
# ~1 MB, far above the ~64 KB OS pipe buffer that made a SimpleQueue worker block its put()
# while the parent waited in join() -- a deadlock that false-timed-out every real media clip.
return b'x' * (1024 * 1024)


class TestDecodeWithTimeout(unittest.TestCase):

def setUp(self):
self._saved_timeout = os.environ.get('MEDIA_DECODE_TIMEOUT')

def tearDown(self):
if self._saved_timeout is None:
os.environ.pop('MEDIA_DECODE_TIMEOUT', None)
else:
os.environ['MEDIA_DECODE_TIMEOUT'] = self._saved_timeout

def test_kills_hung_decode(self):
# A decode that never returns must be killed and surface a TimeoutError rather than hang.
os.environ['MEDIA_DECODE_TIMEOUT'] = '2'
start = time.time()
with self.assertRaises(TimeoutError):
_decode_with_timeout(_sleep_forever)
self.assertLess(time.time() - start, 30)

def test_returns_result_when_fast(self):
os.environ['MEDIA_DECODE_TIMEOUT'] = '10'
self.assertEqual(_decode_with_timeout(_echo, 'ok'), 'ok')

def test_propagates_decode_error(self):
os.environ['MEDIA_DECODE_TIMEOUT'] = '10'
with self.assertRaises(ValueError):
_decode_with_timeout(_raise_value_error)

def test_disabled_calls_directly(self):
# Default (unset / 0): no subprocess, original behavior and zero overhead.
os.environ.pop('MEDIA_DECODE_TIMEOUT', None)
self.assertEqual(_decode_with_timeout(_echo, 'direct'), 'direct')

def test_handles_large_payload(self):
# A decoded payload larger than the OS pipe buffer must transfer without deadlocking the
# worker (regression: the previous SimpleQueue implementation false-timed-out here).
os.environ['MEDIA_DECODE_TIMEOUT'] = '10'
self.assertEqual(_decode_with_timeout(_big_payload), b'x' * (1024 * 1024))


if __name__ == '__main__':
unittest.main()
Loading