diff --git a/swift/model/models/qwen.py b/swift/model/models/qwen.py
index 5ba254e41d..32e4610459 100644
--- a/swift/model/models/qwen.py
+++ b/swift/model/models/qwen.py
@@ -1814,10 +1814,33 @@ def get_model(self, model_dir: str, *args, **kwargs) -> PreTrainedModel:
 
 class Qwen2AudioLoader(ModelLoader):
 
+    @staticmethod
+    def _is_transformers5() -> bool:
+        """Return True when the installed transformers version is 5.0.0 or newer."""
+        import transformers
+        return version.parse(transformers.__version__) >= version.parse('5.0.0')
+
+    def _patch_transformers5_model(self, model: PreTrainedModel) -> PreTrainedModel:
+        """Apply transformers>=5 compatibility tweaks to a loaded model.
+
+        On older transformers versions the model is returned untouched. Otherwise
+        `generation_config.cache_implementation` is reset to None (when present) and
+        `HybridCache.update` is patched to move cached tensors to the incoming device.
+        The same model object is returned in every case.
+        """
+        if not self._is_transformers5():
+            return model
+        generation_config = getattr(model, 'generation_config', None)
+        if generation_config is not None and hasattr(generation_config, 'cache_implementation'):
+            generation_config.cache_implementation = None
+        _patch_hybrid_cache_device_update()
+        return model
+
     def get_model(self, model_dir: str, *args, **kwargs) -> PreTrainedModel:
         from transformers import Qwen2AudioForConditionalGeneration
         self.auto_model_cls = self.auto_model_cls or Qwen2AudioForConditionalGeneration
-        return super().get_model(model_dir, *args, **kwargs)
+        model = super().get_model(model_dir, *args, **kwargs)
+        return self._patch_transformers5_model(model)
 
 
 register_model(
@@ -1832,11 +1855,36 @@ def get_model(self, model_dir: str, *args, **kwargs) -> PreTrainedModel:
         Qwen2AudioLoader,
         model_arch=ModelArch.qwen2_audio,
         architectures=['Qwen2AudioForConditionalGeneration'],
-        requires=['transformers>=4.45,<4.49', 'librosa'],
+        requires=['transformers>=4.48', 'librosa'],
         tags=['audio'],
     ))
 
 
+def _patch_hybrid_cache_device_update() -> None:
+    """Patch `HybridCache.update` so cached tensors follow the incoming states' device.
+
+    The patch is installed at most once (guarded by the `_update_origin` attribute) and is
+    skipped when `HybridCache` cannot be imported from `transformers.cache_utils`.
+    """
+    try:
+        from transformers.cache_utils import HybridCache
+    except ImportError:
+        # Best effort: nothing to patch when this transformers build has no HybridCache.
+        return
+    if hasattr(HybridCache, '_update_origin'):
+        return
+
+    def update(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, *args,
+               **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Move this layer's cache onto the device of the new states before delegating.
+        self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
+        self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
+        return self._update_origin(key_states, value_states, layer_idx, *args, **kwargs)
+
+    HybridCache._update_origin = HybridCache.update
+    HybridCache.update = update
+
+
 class OvisLoader(ModelLoader):
 
     def get_processor(self, model_dir, config) -> Processor:
diff --git a/swift/template/base.py b/swift/template/base.py
index 350e5d9534..fb6d2b2a5d 100644
--- a/swift/template/base.py
+++ b/swift/template/base.py
@@ -8,6 +8,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import warnings
 from collections import defaultdict
 from contextlib import contextmanager, nullcontext
 from copy import deepcopy
@@ -1061,9 +1062,18 @@ def _add_default_tags(inputs: StdTemplateInputs):
                     f'num_media: {num_media}, num_media_tags: {num_media_tags}, total_content: {total_content}. '
                     'We will only replace the frontmost media_tags while keeping the subsequent media_tags.')
 
-    def _encode_context_list(self,
-                             context_list: List[Context],
-                             loss_scale_list: Optional[List[float]] = None) -> Tuple[List[int], List[int], List[float]]:
+    def _encode_context_list(
+        self,
+        context_list: List[Context],
+        loss_scale_list: Optional[List[float]] = None,
+        audio_path_list: Optional[List[str]] = None,
+    ) -> Tuple[List[int], List[int], List[float]]:
+        """Encode a context list and return (input_ids, labels, loss_scale).
+
+        `audio_path_list` is only consulted for qwen2_audio tokenizers: every string
+        context containing `<|AUDIO|>` consumes the next path and is encoded through
+        `self.processor` together with the loaded waveform.
+        """
         is_binary_loss_scale = self.is_binary_loss_scale
         if is_binary_loss_scale is None:
             is_binary_loss_scale = self.loss_scale.is_binary_loss_scale
@@ -1072,13 +1082,40 @@ def _encode_context_list(self,
         loss_scale: List[float] = []
         if loss_scale_list is None:
             loss_scale_list = [0.] * len(context_list)
-        for i, (context, loss_weight) in enumerate(zip(context_list, loss_scale_list)):
-            if isinstance(context, str):
+
+        # The model type does not change while encoding, so resolve it once outside the loop.
+        model_meta = getattr(self.tokenizer, 'model_meta', None)
+        is_qwen2_audio = getattr(model_meta, 'model_type', None) == 'qwen2_audio'
+        audio_ptr = 0
+        for context, loss_weight in zip(context_list, loss_scale_list):
+            if is_qwen2_audio and isinstance(context, str) and '<|AUDIO|>' in context:
+                if audio_path_list is None or audio_ptr >= len(audio_path_list):
+                    warnings.warn('Found <|AUDIO|> but no matching audio input; fallback to text tokenization',
+                                  RuntimeWarning)
+                    token_list = self._tokenize(context)
+                else:
+                    import librosa
+                    sample_rate = self.processor.feature_extractor.sampling_rate
+                    wav, _ = librosa.load(audio_path_list[audio_ptr], sr=sample_rate, mono=True)
+                    encoded = self.processor(
+                        text=context,
+                        audio=wav,
+                        sampling_rate=sample_rate,
+                        return_tensors=None,
+                        add_special_tokens=False,
+                    )
+                    token_list = encoded['input_ids']
+                    # input_ids may come back nested (a list holding one sequence); keep the first sequence.
+                    if len(token_list) > 0 and isinstance(token_list[0], list):
+                        token_list = token_list[0]
+                    audio_ptr += 1
+            elif isinstance(context, str):
                 token_list = self._tokenize(context)
             else:
                 token_list = context
+
             input_ids += token_list
-            if loss_scale_list[i] > 0.0:
+            if loss_weight > 0.0:
                 labels += token_list
             else:
                 labels += [-100] * len(token_list)
@@ -1470,7 +1507,13 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
             loss_scale = encoded['prompt_loss_scale'] + encoded['answer_loss_scale']
         else:
             res_context_list, loss_scale_list = self._simplify_context_list(res_context_list, loss_scale_list, inputs)
-            input_ids, labels, loss_scale = self._encode_context_list(res_context_list, loss_scale_list)
+            # Only qwen2_audio is handed the audio list; every other model keeps the two-argument call.
+            model_meta = getattr(self.tokenizer, 'model_meta', None)
+            if getattr(model_meta, 'model_type', None) == 'qwen2_audio':
+                input_ids, labels, loss_scale = self._encode_context_list(res_context_list, loss_scale_list,
+                                                                          inputs.audios)
+            else:
+                input_ids, labels, loss_scale = self._encode_context_list(res_context_list, loss_scale_list)
             self._add_dynamic_eos(input_ids, labels, loss_scale, self._encode_context_list(self.template_meta.suffix)[0])
         encoded['input_ids'] = input_ids