-
Notifications
You must be signed in to change notification settings - Fork 1.5k
[compat] Support Qwen2-Audio with newer transformers #9453
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,6 +8,7 @@ | |
| import torch | ||
| import torch.nn as nn | ||
| import torch.nn.functional as F | ||
| import warnings | ||
| from collections import defaultdict | ||
| from contextlib import contextmanager, nullcontext | ||
| from copy import deepcopy | ||
|
|
@@ -1061,9 +1062,12 @@ def _add_default_tags(inputs: StdTemplateInputs): | |
| f'num_media: {num_media}, num_media_tags: {num_media_tags}, total_content: {total_content}. ' | ||
| 'We will only replace the frontmost media_tags while keeping the subsequent media_tags.') | ||
|
|
||
| def _encode_context_list(self, | ||
| context_list: List[Context], | ||
| loss_scale_list: Optional[List[float]] = None) -> Tuple[List[int], List[int], List[float]]: | ||
| def _encode_context_list( | ||
| self, | ||
| context_list: List[Context], | ||
| loss_scale_list: Optional[List[float]] = None, | ||
| audio_path_list: Optional[List[str]] = None, | ||
| ) -> Tuple[List[int], List[int], List[float]]: | ||
| is_binary_loss_scale = self.is_binary_loss_scale | ||
| if is_binary_loss_scale is None: | ||
| is_binary_loss_scale = self.loss_scale.is_binary_loss_scale | ||
|
|
@@ -1072,13 +1076,35 @@ def _encode_context_list(self, | |
| loss_scale: List[float] = [] | ||
| if loss_scale_list is None: | ||
| loss_scale_list = [0.] * len(context_list) | ||
| for i, (context, loss_weight) in enumerate(zip(context_list, loss_scale_list)): | ||
| if isinstance(context, str): | ||
| token_list = self._tokenize(context) | ||
|
|
||
| audio_ptr = 0 | ||
| for context, loss_weight in zip(context_list, loss_scale_list): | ||
| if (isinstance(context, str) and '<|AUDIO|>' in context and getattr(self.tokenizer, 'model_meta', None) | ||
| and getattr(self.tokenizer.model_meta, 'model_type', None) == 'qwen2_audio'): | ||
| if audio_path_list is None or audio_ptr >= len(audio_path_list): | ||
| warnings.warn('Found <|AUDIO|> but no matching audio input; fallback to text tokenization', | ||
| RuntimeWarning) | ||
| token_list = self._tokenize(context) | ||
| else: | ||
| import librosa | ||
| sample_rate = self.processor.feature_extractor.sampling_rate | ||
| wav, _ = librosa.load(audio_path_list[audio_ptr], sr=sample_rate, mono=True) | ||
| encoded = self.processor( | ||
| text=context, | ||
| audio=wav, | ||
| sampling_rate=sample_rate, | ||
| return_tensors=None, | ||
| add_special_tokens=False, | ||
| ) | ||
| token_list = encoded['input_ids'] | ||
| if len(token_list) > 0 and isinstance(token_list[0], list): | ||
| token_list = token_list[0] | ||
| audio_ptr += 1 | ||
|
Comment on lines
+1081
to
+1102
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. To prevent other models from accidentally triggering the audio processing path (which would raise an error), restrict it to the `qwen2_audio` model type:
for context, loss_weight in zip(context_list, loss_scale_list):
if (isinstance(context, str)
and '<|AUDIO|>' in context
and getattr(self.tokenizer, 'model_meta', None)
and getattr(self.tokenizer.model_meta, 'model_type', None) == 'qwen2_audio'):
if audio_path_list is None or audio_ptr >= len(audio_path_list):
warnings.warn('Found <|AUDIO|> but no matching audio input; fallback to text tokenization',
RuntimeWarning)
token_list = self._tokenize(context)
else:
import librosa
sample_rate = self.processor.feature_extractor.sampling_rate
wav, _ = librosa.load(audio_path_list[audio_ptr], sr=sample_rate, mono=True)
encoded = self.processor(
text=context,
audio=wav,
sampling_rate=sample_rate,
return_tensors=None,
add_special_tokens=False,
)
token_list = encoded['input_ids']
if len(token_list) > 0 and isinstance(token_list[0], list):
token_list = token_list[0]
audio_ptr += 1
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks for the review. I updated the PR to:
|
||
| else: | ||
| token_list = context | ||
| token_list = self._tokenize(context) if isinstance(context, str) else context | ||
|
|
||
| input_ids += token_list | ||
| if loss_scale_list[i] > 0.0: | ||
| if loss_weight > 0.0: | ||
| labels += token_list | ||
| else: | ||
| labels += [-100] * len(token_list) | ||
|
|
@@ -1470,7 +1496,11 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: | |
| loss_scale = encoded['prompt_loss_scale'] + encoded['answer_loss_scale'] | ||
| else: | ||
| res_context_list, loss_scale_list = self._simplify_context_list(res_context_list, loss_scale_list, inputs) | ||
| input_ids, labels, loss_scale = self._encode_context_list(res_context_list, loss_scale_list) | ||
| if self.tokenizer.model_meta.model_type and self.tokenizer.model_meta.model_type == 'qwen2_audio': | ||
| input_ids, labels, loss_scale = self._encode_context_list(res_context_list, loss_scale_list, | ||
| inputs.audios) | ||
| else: | ||
| input_ids, labels, loss_scale = self._encode_context_list(res_context_list, loss_scale_list) | ||
| self._add_dynamic_eos(input_ids, labels, loss_scale, self._encode_context_list(self.template_meta.suffix)[0]) | ||
|
|
||
| encoded['input_ids'] = input_ids | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `transformers` module is not imported at the top of `swift/model/models/qwen.py`. Calling `_is_transformers5()` will raise a `NameError` at runtime. Please import `transformers` locally inside the method or at the module level.