From e82f11372ba68c7c39531ecf166f020f1245d3e7 Mon Sep 17 00:00:00 2001 From: MWXGOD <1491377079@qq.com> Date: Sat, 30 May 2026 19:28:31 +0800 Subject: [PATCH 1/2] Support Qwen2-Audio with newer transformers --- swift/model/models/qwen.py | 35 ++++++++++++++++++++++++++-- swift/template/base.py | 47 ++++++++++++++++++++++++++++++-------- 2 files changed, 71 insertions(+), 11 deletions(-) diff --git a/swift/model/models/qwen.py b/swift/model/models/qwen.py index 5ba254e41d..de84633d47 100644 --- a/swift/model/models/qwen.py +++ b/swift/model/models/qwen.py @@ -1814,10 +1814,24 @@ def get_model(self, model_dir: str, *args, **kwargs) -> PreTrainedModel: class Qwen2AudioLoader(ModelLoader): + @staticmethod + def _is_transformers5() -> bool: + return version.parse(transformers.__version__) >= version.parse('5.0.0') + + def _patch_transformers5_model(self, model: PreTrainedModel) -> PreTrainedModel: + if not self._is_transformers5(): + return model + generation_config = getattr(model, 'generation_config', None) + if generation_config is not None and hasattr(generation_config, 'cache_implementation'): + generation_config.cache_implementation = None + _patch_hybrid_cache_device_update() + return model + def get_model(self, model_dir: str, *args, **kwargs) -> PreTrainedModel: from transformers import Qwen2AudioForConditionalGeneration self.auto_model_cls = self.auto_model_cls or Qwen2AudioForConditionalGeneration - return super().get_model(model_dir, *args, **kwargs) + model = super().get_model(model_dir, *args, **kwargs) + return self._patch_transformers5_model(model) register_model( @@ -1832,11 +1846,28 @@ def get_model(self, model_dir: str, *args, **kwargs) -> PreTrainedModel: Qwen2AudioLoader, model_arch=ModelArch.qwen2_audio, architectures=['Qwen2AudioForConditionalGeneration'], - requires=['transformers>=4.45,<4.49', 'librosa'], + requires=['transformers>=4.48', 'librosa'], tags=['audio'], )) +def _patch_hybrid_cache_device_update() -> None: + try: + from transformers.cache_utils import HybridCache + + def update(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, *args, + **kwargs) -> Tuple[torch.Tensor]: + self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device) + self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device) + return self._update_origin(key_states, value_states, layer_idx, *args, **kwargs) + + if not hasattr(HybridCache, '_update_origin'): + HybridCache._update_origin = HybridCache.update + HybridCache.update = update + except ImportError: + pass + + class OvisLoader(ModelLoader): def get_processor(self, model_dir, config) -> Processor: diff --git a/swift/template/base.py b/swift/template/base.py index 350e5d9534..ded0fbe585 100644 --- a/swift/template/base.py +++ b/swift/template/base.py @@ -1,6 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import hashlib import inspect +import librosa import math import os import random @@ -8,6 +9,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +import warnings from collections import defaultdict from contextlib import contextmanager, nullcontext from copy import deepcopy @@ -1061,9 +1063,12 @@ def _add_default_tags(inputs: StdTemplateInputs): f'num_media: {num_media}, num_media_tags: {num_media_tags}, total_content: {total_content}. ' 'We will only replace the frontmost media_tags while keeping the subsequent media_tags.') - def _encode_context_list(self, - context_list: List[Context], - loss_scale_list: Optional[List[float]] = None) -> Tuple[List[int], List[int], List[float]]: + def _encode_context_list( + self, + context_list: List[Context], + loss_scale_list: Optional[List[float]] = None, + audio_path_list: Optional[List[str]] = None, + ) -> Tuple[List[int], List[int], List[float]]: is_binary_loss_scale = self.is_binary_loss_scale if is_binary_loss_scale is None: is_binary_loss_scale = self.loss_scale.is_binary_loss_scale @@ -1072,13 +1077,33 @@ def _encode_context_list(self, loss_scale: List[float] = [] if loss_scale_list is None: loss_scale_list = [0.] * len(context_list) - for i, (context, loss_weight) in enumerate(zip(context_list, loss_scale_list)): - if isinstance(context, str): - token_list = self._tokenize(context) + + audio_ptr = 0 + for context, loss_weight in zip(context_list, loss_scale_list): + if isinstance(context, str) and '<|AUDIO|>' in context: + if audio_path_list is None or audio_ptr >= len(audio_path_list): + warnings.warn('Found <|AUDIO|> but no matching audio input; fallback to text tokenization', + RuntimeWarning) + token_list = self._tokenize(context) + else: + sample_rate = self.processor.feature_extractor.sampling_rate + wav, _ = librosa.load(audio_path_list[audio_ptr], sr=sample_rate, mono=True) + encoded = self.processor( + text=context, + audio=wav, + sampling_rate=sample_rate, + return_tensors=None, + add_special_tokens=False, + ) + token_list = encoded['input_ids'] + if len(token_list) > 0 and isinstance(token_list[0], list): + token_list = token_list[0] + audio_ptr += 1 else: - token_list = context + token_list = self._tokenize(context) if isinstance(context, str) else context + input_ids += token_list - if loss_scale_list[i] > 0.0: + if loss_weight > 0.0: labels += token_list else: labels += [-100] * len(token_list) @@ -1470,7 +1495,11 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: loss_scale = encoded['prompt_loss_scale'] + encoded['answer_loss_scale'] else: res_context_list, loss_scale_list = self._simplify_context_list(res_context_list, loss_scale_list, inputs) - input_ids, labels, loss_scale = self._encode_context_list(res_context_list, loss_scale_list) + if self.tokenizer.model_meta.model_type and self.tokenizer.model_meta.model_type == 'qwen2_audio': + input_ids, labels, loss_scale = self._encode_context_list(res_context_list, loss_scale_list, + inputs.audios) + else: + input_ids, labels, loss_scale = self._encode_context_list(res_context_list, loss_scale_list) self._add_dynamic_eos(input_ids, labels, loss_scale, self._encode_context_list(self.template_meta.suffix)[0]) encoded['input_ids'] = input_ids From 4e55dbeb09a88ec1d2209b71ca10531d0f1574f2 Mon Sep 17 00:00:00 2001 From: MWXGOD <1491377079@qq.com> Date: Sat, 30 May 2026 19:38:56 +0800 Subject: [PATCH 2/2] Address Qwen2-Audio compatibility review comments --- swift/model/models/qwen.py | 1 + swift/template/base.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/swift/model/models/qwen.py b/swift/model/models/qwen.py index de84633d47..32e4610459 100644 --- a/swift/model/models/qwen.py +++ b/swift/model/models/qwen.py @@ -1816,6 +1816,7 @@ class Qwen2AudioLoader(ModelLoader): @staticmethod def _is_transformers5() -> bool: + import transformers return version.parse(transformers.__version__) >= version.parse('5.0.0') def _patch_transformers5_model(self, model: PreTrainedModel) -> PreTrainedModel: diff --git a/swift/template/base.py b/swift/template/base.py index ded0fbe585..fb6d2b2a5d 100644 --- a/swift/template/base.py +++ b/swift/template/base.py @@ -1,7 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import hashlib import inspect -import librosa import math import os import random @@ -1080,12 +1079,14 @@ def _encode_context_list( audio_ptr = 0 for context, loss_weight in zip(context_list, loss_scale_list): - if isinstance(context, str) and '<|AUDIO|>' in context: + if (isinstance(context, str) and '<|AUDIO|>' in context and getattr(self.tokenizer, 'model_meta', None) + and getattr(self.tokenizer.model_meta, 'model_type', None) == 'qwen2_audio'): if audio_path_list is None or audio_ptr >= len(audio_path_list): warnings.warn('Found <|AUDIO|> but no matching audio input; fallback to text tokenization', RuntimeWarning) token_list = self._tokenize(context) else: + import librosa sample_rate = self.processor.feature_extractor.sampling_rate wav, _ = librosa.load(audio_path_list[audio_ptr], sr=sample_rate, mono=True) encoded = self.processor(