Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions swift/model/models/qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1814,10 +1814,25 @@ def get_model(self, model_dir: str, *args, **kwargs) -> PreTrainedModel:

class Qwen2AudioLoader(ModelLoader):

@staticmethod
def _is_transformers5() -> bool:
import transformers
return version.parse(transformers.__version__) >= version.parse('5.0.0')
Comment on lines +1817 to +1820

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The transformers module is not imported at the top of swift/model/models/qwen.py. Calling _is_transformers5() will raise a NameError at runtime. Please import transformers locally inside the method or at the module level.

Suggested change
@staticmethod
def _is_transformers5() -> bool:
return version.parse(transformers.__version__) >= version.parse('5.0.0')
@staticmethod
def _is_transformers5() -> bool:
import transformers
return version.parse(transformers.__version__) >= version.parse('5.0.0')


def _patch_transformers5_model(self, model: PreTrainedModel) -> PreTrainedModel:
if not self._is_transformers5():
return model
generation_config = getattr(model, 'generation_config', None)
if generation_config is not None and hasattr(generation_config, 'cache_implementation'):
generation_config.cache_implementation = None
_patch_hybrid_cache_device_update()
return model

def get_model(self, model_dir: str, *args, **kwargs) -> PreTrainedModel:
from transformers import Qwen2AudioForConditionalGeneration
self.auto_model_cls = self.auto_model_cls or Qwen2AudioForConditionalGeneration
return super().get_model(model_dir, *args, **kwargs)
model = super().get_model(model_dir, *args, **kwargs)
return self._patch_transformers5_model(model)


register_model(
Expand All @@ -1832,11 +1847,28 @@ def get_model(self, model_dir: str, *args, **kwargs) -> PreTrainedModel:
Qwen2AudioLoader,
model_arch=ModelArch.qwen2_audio,
architectures=['Qwen2AudioForConditionalGeneration'],
requires=['transformers>=4.45,<4.49', 'librosa'],
requires=['transformers>=4.48', 'librosa'],
tags=['audio'],
))


def _patch_hybrid_cache_device_update() -> None:
try:
from transformers.cache_utils import HybridCache

def update(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, *args,
**kwargs) -> Tuple[torch.Tensor]:
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
return self._update_origin(key_states, value_states, layer_idx, *args, **kwargs)

if not hasattr(HybridCache, '_update_origin'):
HybridCache._update_origin = HybridCache.update
HybridCache.update = update
except ImportError:
pass


class OvisLoader(ModelLoader):

def get_processor(self, model_dir, config) -> Processor:
Expand Down
48 changes: 39 additions & 9 deletions swift/template/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from copy import deepcopy
Expand Down Expand Up @@ -1061,9 +1062,12 @@ def _add_default_tags(inputs: StdTemplateInputs):
f'num_media: {num_media}, num_media_tags: {num_media_tags}, total_content: {total_content}. '
'We will only replace the frontmost media_tags while keeping the subsequent media_tags.')

def _encode_context_list(self,
context_list: List[Context],
loss_scale_list: Optional[List[float]] = None) -> Tuple[List[int], List[int], List[float]]:
def _encode_context_list(
self,
context_list: List[Context],
loss_scale_list: Optional[List[float]] = None,
audio_path_list: Optional[List[str]] = None,
) -> Tuple[List[int], List[int], List[float]]:
is_binary_loss_scale = self.is_binary_loss_scale
if is_binary_loss_scale is None:
is_binary_loss_scale = self.loss_scale.is_binary_loss_scale
Expand All @@ -1072,13 +1076,35 @@ def _encode_context_list(self,
loss_scale: List[float] = []
if loss_scale_list is None:
loss_scale_list = [0.] * len(context_list)
for i, (context, loss_weight) in enumerate(zip(context_list, loss_scale_list)):
if isinstance(context, str):
token_list = self._tokenize(context)

audio_ptr = 0
for context, loss_weight in zip(context_list, loss_scale_list):
if (isinstance(context, str) and '<|AUDIO|>' in context and getattr(self.tokenizer, 'model_meta', None)
and getattr(self.tokenizer.model_meta, 'model_type', None) == 'qwen2_audio'):
if audio_path_list is None or audio_ptr >= len(audio_path_list):
warnings.warn('Found <|AUDIO|> but no matching audio input; fallback to text tokenization',
RuntimeWarning)
token_list = self._tokenize(context)
else:
import librosa
sample_rate = self.processor.feature_extractor.sampling_rate
wav, _ = librosa.load(audio_path_list[audio_ptr], sr=sample_rate, mono=True)
encoded = self.processor(
text=context,
audio=wav,
sampling_rate=sample_rate,
return_tensors=None,
add_special_tokens=False,
)
token_list = encoded['input_ids']
if len(token_list) > 0 and isinstance(token_list[0], list):
token_list = token_list[0]
audio_ptr += 1
Comment on lines +1081 to +1102

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To prevent other models from accidentally triggering the audio processing path (which would raise an AttributeError since non-audio models do not have a feature_extractor on their processor), we should restrict this block to only run when the model type is qwen2_audio. Additionally, we should import librosa locally here to avoid making it a global hard dependency.

        for context, loss_weight in zip(context_list, loss_scale_list):
            if (isinstance(context, str)
                    and '<|AUDIO|>' in context
                    and getattr(self.tokenizer, 'model_meta', None)
                    and getattr(self.tokenizer.model_meta, 'model_type', None) == 'qwen2_audio'):
                if audio_path_list is None or audio_ptr >= len(audio_path_list):
                    warnings.warn('Found <|AUDIO|> but no matching audio input; fallback to text tokenization',
                                  RuntimeWarning)
                    token_list = self._tokenize(context)
                else:
                    import librosa
                    sample_rate = self.processor.feature_extractor.sampling_rate
                    wav, _ = librosa.load(audio_path_list[audio_ptr], sr=sample_rate, mono=True)
                    encoded = self.processor(
                        text=context,
                        audio=wav,
                        sampling_rate=sample_rate,
                        return_tensors=None,
                        add_special_tokens=False,
                    )
                    token_list = encoded['input_ids']
                    if len(token_list) > 0 and isinstance(token_list[0], list):
                        token_list = token_list[0]
                    audio_ptr += 1

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review. I updated the PR to:

  1. import transformers locally in _is_transformers5;
  2. remove the global librosa import and import it lazily in the Qwen2-Audio branch;
  3. restrict the audio placeholder branch to qwen2_audio only.

else:
token_list = context
token_list = self._tokenize(context) if isinstance(context, str) else context

input_ids += token_list
if loss_scale_list[i] > 0.0:
if loss_weight > 0.0:
labels += token_list
else:
labels += [-100] * len(token_list)
Expand Down Expand Up @@ -1470,7 +1496,11 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
loss_scale = encoded['prompt_loss_scale'] + encoded['answer_loss_scale']
else:
res_context_list, loss_scale_list = self._simplify_context_list(res_context_list, loss_scale_list, inputs)
input_ids, labels, loss_scale = self._encode_context_list(res_context_list, loss_scale_list)
if self.tokenizer.model_meta.model_type and self.tokenizer.model_meta.model_type == 'qwen2_audio':
input_ids, labels, loss_scale = self._encode_context_list(res_context_list, loss_scale_list,
inputs.audios)
else:
input_ids, labels, loss_scale = self._encode_context_list(res_context_list, loss_scale_list)
self._add_dynamic_eos(input_ids, labels, loss_scale, self._encode_context_list(self.template_meta.suffix)[0])

encoded['input_ids'] = input_ids
Expand Down