diff --git a/swift/template/vision_utils.py b/swift/template/vision_utils.py
index 9b828f6765..20ae927847 100644
--- a/swift/template/vision_utils.py
+++ b/swift/template/vision_utils.py
@@ -103,6 +103,10 @@ def rescale_image(img: Image.Image, max_pixels: int) -> Image.Image:
 
 def _check_path(path: str) -> Union[str, None]:
     """If it is a path, return the string; if it is base64, return None."""
+    if not isinstance(path, str):
+        # bytes audio/image data is not a path; let the caller fall back to it
+        # instead of crashing on the str-only checks below (e.g. startswith).
+        return None
     MAX_PATH_HEURISTIC = 2000
     if len(path) > MAX_PATH_HEURISTIC:
         return
@@ -302,7 +306,7 @@ def load_audio(audio: Union[str, bytes], sampling_rate: int, return_sr: bool = F
         audio_io = load_file(audio)
         res = librosa.load(audio_io, sr=sampling_rate)
     except Exception:
-        if audio.startswith(('http://', 'https://')):
+        if isinstance(audio, str) and audio.startswith(('http://', 'https://')):
             import audioread
             audio_io = audioread.ffdec.FFmpegAudioFile(audio)
         else:
diff --git a/tests/general/test_template.py b/tests/general/test_template.py
index da6270ae69..5d472e24e6 100644
--- a/tests/general/test_template.py
+++ b/tests/general/test_template.py
@@ -77,3 +77,40 @@ def test_mllm_dataset_map():
     test_mllm()
     test_llm_dataset_map()
     test_mllm_dataset_map()
+
+
+def test_load_audio_bytes_input_does_not_crash_on_fallback(monkeypatch):
+    import sys
+    import types
+
+    from swift.template import vision_utils
+
+    calls = []
+
+    fake_librosa = types.ModuleType('librosa')
+
+    def fake_load(audio_io, sr):
+        calls.append(audio_io)
+        if len(calls) == 1:
+            # First attempt fails (e.g. a format soundfile can't read), forcing
+            # the except branch that used to call bytes.startswith and crash.
+            raise RuntimeError('first load fails')
+        return ([0.1, 0.2], sr)
+
+    fake_librosa.load = fake_load
+    monkeypatch.setitem(sys.modules, 'librosa', fake_librosa)
+
+    # bytes audio (allowed by the Union[str, bytes] signature) must not raise a
+    # TypeError from `audio.startswith(...)` or from `_check_path(bytes)` when
+    # the first decode fails and the except branch runs.
+    result = vision_utils.load_audio(b'\x00\x01raw-audio-bytes', sampling_rate=16000)
+
+    assert result == [0.1, 0.2]
+
+
+def test_check_path_with_bytes_returns_none():
+    from swift.template.vision_utils import _check_path
+
+    # bytes input is not a path; it must return None instead of raising a
+    # TypeError from the str-only checks (len/os.path/startswith) below.
+    assert _check_path(b'\x00\x01raw-bytes') is None