Support return_tensors in audio chat templates #34601

Open · wants to merge 10 commits into main
42 changes: 42 additions & 0 deletions src/transformers/audio_utils.py
@@ -17,10 +17,52 @@
and remove unnecessary dependencies.
"""

import os
import warnings
from io import BytesIO
from typing import List, Optional, Tuple, Union

import numpy as np
import requests

from .utils import is_librosa_available, requires_backends


if is_librosa_available():
import librosa


def load_audio(audio: Union[str, np.ndarray], sampling_rate=16000, timeout=None) -> np.ndarray:
"""
Loads `audio` to an np.ndarray object.

Args:
audio (`str` or `np.ndarray`):
The audio to be loaded into numpy array format.
sampling_rate (`int`, *optional*, defaults to 16000):
The sampling rate to be used when loading the audio. It should be the same as the
sampling rate that the model you will be using was trained with.
timeout (`float`, *optional*):
The timeout value in seconds for the URL request.

Returns:
`np.ndarray`: A numpy array representing the audio.
"""
requires_backends(load_audio, ["librosa"])

if isinstance(audio, str):
# Load audio from URL (e.g https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/translate_to_chinese.wav)
if audio.startswith("http://") or audio.startswith("https://"):
audio = librosa.load(BytesIO(requests.get(audio, timeout=timeout).content), sr=sampling_rate)[0]
elif os.path.isfile(audio):
audio = librosa.load(audio, sr=sampling_rate)[0]
elif isinstance(audio, np.ndarray):
audio = audio
else:
raise TypeError(
"Incorrect format used for `audio`. Should be an url linking to an audio, a local path, or numpy array."
)
return audio


def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
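For illustration, a minimal usage sketch of the new `load_audio` helper, assuming `librosa` is installed (the URL and file path below are placeholders):

```python
import numpy as np

from transformers.audio_utils import load_audio

# Download and decode a remote file (placeholder URL), resampled to 16 kHz
waveform = load_audio("https://example.com/sample.wav", sampling_rate=16000, timeout=5)

# Decode a local file (placeholder path)
waveform = load_audio("path/to/sample.wav", sampling_rate=16000)

# A numpy array is passed through unchanged
waveform = load_audio(np.zeros(16000, dtype=np.float32))
```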
113 changes: 77 additions & 36 deletions src/transformers/models/qwen2_audio/processing_qwen2_audio.py
@@ -16,13 +16,24 @@
Processor class for Qwen2Audio.
"""

from typing import List, Optional, Union
import warnings
from typing import List, Union

import numpy as np

from ...feature_extraction_utils import BatchFeature
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils.deprecation import deprecate_kwarg


class Qwen2AudioProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
},
"audio_kwargs": {},
}


class Qwen2AudioProcessor(ProcessorMixin):
@@ -68,13 +79,15 @@ def __init__(
self.audio_eos_token = tokenizer.audio_eos_token if hasattr(tokenizer, "audio_eos_token") else audio_eos_token
super().__init__(feature_extractor, tokenizer, chat_template=chat_template)

@deprecate_kwarg("audios", version="4.54.0", new_name="audio")
def __call__(
self,
images=None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
audios: Union[np.ndarray, List[np.ndarray]] = None,
padding: Union[bool, str, PaddingStrategy] = False,
sampling_rate: Optional[int] = None,
**kwargs,
audio: Union[np.ndarray, List[np.ndarray]] = None,
videos=None,
audios=None, # kept for BC
**kwargs: Unpack[Qwen2AudioProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequences(s) and audio(s). This method forwards the `text`
@@ -88,43 +101,71 @@ def __call__(
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
audios (`np.ndarray`, `List[np.ndarray]`):
audio (`np.ndarray`, `List[np.ndarray]`):
The audio or batch of audios to be prepared. Each audio can be a NumPy array.
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence is provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
lengths).
sampling_rate (`int`, defaults to 16000):
The sampling rate at which the audio files should be digitized, expressed in hertz (Hz).
"""

# Handle BC when user passes positional args and `audio` gets assigned the second argument
arguments_passed_correctly = False
# case 1: when processor("hello", optional[my_audio])
if images is not None and audio is None and audios is None:
if text is not None:
audio = text
text = images
else:
text = images
# case 2: when processor("hello", audios=my_audio)
elif images is not None and audios is not None and text is None:
audio = audios
text = images
# case 3: when processor(text="hello", audios=my_audio)
elif text is not None and audios is not None and audio is None and images is None:
audio = audios
# case 4: when processor(text="hello", audio=my_audio), the only correct way to pass args
elif text is not None and audio is not None and audios is None and images is None:
arguments_passed_correctly = True
else:
raise ValueError(
"Could not infer input arguments. It is strongly recommended to pass inputs as keyword arguments "
"with keys `text` and `audio` for correct processing."
)

if not arguments_passed_correctly:
warnings.warn(
"You may have used the wrong order or keyword for inputs. It is strongly recommended to pass inputs as keyword arguments "
"with keys `audio` and `text`. This behavior will be deprecated and positional arguments will throw an error in transformers v4.54.",
FutureWarning,
)

if text is None:
raise ValueError("You need to specify either a `text` input to process.")
raise ValueError("You need to specify `text` input to process.")
elif isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")

# ensure we have as many audios as audio tokens
num_audio_tokens = sum(sample.count(self.audio_token) for sample in text)
num_audios = 1 if type(audios) == np.ndarray else len(audios)
if num_audio_tokens != num_audios:
raise ValueError(
f"Found {num_audio_tokens} {self.audio_token} token{'s' if num_audio_tokens > 1 else ''} in provided text but received {num_audios} audio{'s' if num_audios > 1 else ''}"
)
# num_audio_tokens = sum(sample.count(self.audio_token) for sample in text)
# num_audios = 1 if type(audio) == np.ndarray else len(audio)
# if num_audio_tokens != num_audios:
# raise ValueError(
# f"Found {num_audio_tokens} {self.audio_token} token{'s' if num_audio_tokens > 1 else ''} in provided text but received {num_audios} audio{'s' if num_audios > 1 else ''}"
# )

output_kwargs = self._merge_kwargs(
Qwen2AudioProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)

if audios is not None:
audio_inputs = self.feature_extractor(
audios, sampling_rate=sampling_rate, return_attention_mask=True, padding="max_length", **kwargs
)
audio_inputs["feature_attention_mask"] = audio_inputs.pop(
"attention_mask"
) # rename attention_mask to prevent conflicts later on
if audio is not None:
# Some kwargs should not be changed so we can expand text with audio tokens below
output_kwargs["audio_kwargs"]["return_attention_mask"] = True
output_kwargs["audio_kwargs"]["padding"] = "max_length"
audio_inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])

# rename attention_mask to prevent conflicts later on
audio_inputs["feature_attention_mask"] = audio_inputs.pop("attention_mask")

expanded_text = []
audio_lengths = audio_inputs["feature_attention_mask"].sum(-1).tolist()
Expand Down Expand Up @@ -162,9 +203,9 @@ def __call__(
expanded_text.append(sample)
text = expanded_text

inputs = self.tokenizer(text, padding=padding, **kwargs)
inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])

if audios is not None:
if audio is not None:
inputs.update(audio_inputs)

return BatchFeature(data={**inputs})
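For illustration, a sketch of the updated calling convention under the new kwargs handling; the checkpoint is the public Qwen2-Audio instruct model and the waveform is dummy data:

```python
import numpy as np

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
waveform = np.zeros(16000, dtype=np.float32)  # 1 second of silence at 16 kHz
text = "<|audio_bos|><|AUDIO|><|audio_eos|>What is said in this clip?"

# Preferred: keyword arguments, with the audio input named `audio`
inputs = processor(text=text, audio=waveform, return_tensors="pt")

# Still accepted for backward compatibility, but emits a FutureWarning
legacy_inputs = processor(text=text, audios=waveform, return_tensors="pt")
```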
Expand Down Expand Up @@ -228,7 +269,7 @@ def default_chat_template(self):
"{{ message['content'] }}<|im_end|>\n"
"{% else %}"
"{% for content in message['content'] %}"
"{% if 'audio' in content or 'audio_url' in content %}"
"{% if 'audio' in content or 'audio_url' in content or message['type'] == 'audio' %}"
"{% set audio_count.value = audio_count.value + 1 %}"
"Audio {{ audio_count.value }}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
"{% elif 'text' in content %}"
31 changes: 23 additions & 8 deletions src/transformers/processing_utils.py
@@ -29,6 +29,7 @@
import numpy as np
import typing_extensions

from .audio_utils import load_audio
from .dynamic_module_utils import custom_object_save
from .image_utils import ChannelDimension, is_valid_image, is_vision_available, load_image, load_video

@@ -379,6 +380,8 @@ class ChatTemplateKwargs(TypedDict, total=False):
The backend to use when loading the video which will be used only when there are videos in the conversation.
Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "pyav" because it is the only backend
that supports all types of sources to load from.
sampling_rate (`int`, *optional*, defaults to `16_000`):
The sampling rate at which the given audio file should be loaded. Defaults to `16_000`.
"""

tokenize: Optional[bool] = False
@@ -390,6 +393,7 @@ class ChatTemplateKwargs(TypedDict, total=False):
return_assistant_tokens_mask: Optional[bool] = False
num_frames: Optional[int] = None
video_load_backend: Optional[str] = "pyav"
sampling_rate: Optional[int] = 16_000


class AllKwargsForChatTemplate(
@@ -920,6 +924,7 @@ class MyProcessingKwargs(ProcessingKwargs, CommonKwargs, TextKwargs, ImagesKwarg
if hasattr(self.tokenizer, modality_key)
else tokenizer_init_kwargs[modality_key]
)
default_kwargs[modality][modality_key] = value
# now defaults kwargs are updated with the tokenizers defaults.
# pass defaults to output dictionary
@@ -1215,6 +1220,7 @@ def apply_chat_template(
return_dict = chat_template_kwargs.pop("return_dict")
num_frames = chat_template_kwargs.pop("num_frames")
video_load_backend = chat_template_kwargs.pop("video_load_backend")
sampling_rate = chat_template_kwargs.pop("sampling_rate")

prompt = self.tokenizer.apply_chat_template(
conversation,
@@ -1228,24 +1234,33 @@
# we will have to return all processed inputs in a dict
if tokenize:
images, videos = [], []
audios = []
for message in conversation:
visuals = [content for content in message["content"] if content["type"] in ["image", "video"]]
for vision_info in visuals:
if vision_info["type"] == "image":
# Load videos, images and audios if they exist
multimodals = [
content for content in message["content"] if content["type"] in ["image", "video", "audio"]
]
for dict_info in multimodals:
if dict_info["type"] == "image":
for key in ["image", "url", "path", "base64"]:
if key in vision_info:
images.append(load_image(vision_info[key]))
elif vision_info["type"] == "video":
if key in dict_info:
images.append(load_image(dict_info[key]))
elif dict_info["type"] == "video":
for key in ["video", "url", "path"]:
if key in vision_info:
if key in dict_info:
videos.append(
load_video(vision_info[key], num_frames=num_frames, backend=video_load_backend)
load_video(dict_info[key], num_frames=num_frames, backend=video_load_backend)
)
elif dict_info["type"] == "audio":
for key in ["audio", "url", "path"]:
if key in dict_info:
audios.append(load_audio(dict_info[key], sampling_rate=sampling_rate))

out = self(
text=prompt,
images=images if images else None,
videos=videos if videos else None,
audios=audios if audios else None,
**kwargs,
)
if return_dict:
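Finally, a sketch of the end-to-end path this PR targets: with `tokenize=True`, audio entries in the conversation are loaded at the requested `sampling_rate` and forwarded to the processor together with the rendered prompt. The audio URL is a placeholder, and the key names follow the loader above (`audio`, `url`, or `path`):

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio": "https://example.com/clip.wav"},  # placeholder URL
            {"type": "text", "text": "What can you hear in this recording?"},
        ],
    }
]

inputs = processor.apply_chat_template(
    conversation,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    sampling_rate=16000,
)
# `inputs` should contain `input_ids` plus the audio features
# (`input_features` and `feature_attention_mask`) ready for `model.generate`.
```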